class Callback(ABC):
    '''Abstract base class for callbacks used in :code:`Algorithm.run(callbacks: list[Callback])`.

    Subclasses implement :code:`__call__(algorithm)`; the base class itself
    only records the verbosity level.

    Parameters
    ----------
    verbose: int, choice of 0,1,2, default 1
        0=quiet, 1=info, 2=debug.
    '''
    def __init__(self, verbose: int = 1) -> None:
        self.verbose = verbose

    @abstractmethod
    def __call__(self, algorithm):
        '''Hook invoked with the :code:`algorithm` instance; must be overridden.'''
class _OldCallback(Callback):
    '''Adapter that wraps an old-style :code:`def callback` as a new-style :code:`class Callback`.

    Parameters
    ----------
    callback: :code:`callable(iteration, objective, x)`
        Invoked only when ``algorithm.update_objective_interval > 0`` and the
        current ``algorithm.iteration`` is a multiple of that interval.
    *args, **kwargs:
        Forwarded to :code:`Callback.__init__` (e.g. ``verbose``).
    '''
    def __init__(self, callback, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.func = callback

    def __call__(self, algorithm):
        interval = algorithm.update_objective_interval
        # Skip iterations on which the objective is not updated.
        if interval <= 0 or algorithm.iteration % interval != 0:
            return
        iteration = algorithm.iteration
        # return_all is requested at debug verbosity (verbose >= 2)
        objective = algorithm.get_last_objective(return_all=self.verbose >= 2)
        self.func(iteration, objective, algorithm.x)
class ProgressCallback(Callback):
    ''':code:`tqdm`-based progress bar.

    The bar is created lazily on the first call so that its defaults can be
    taken from the algorithm being run.

    Parameters
    ----------
    tqdm_class: default :code:`tqdm.auto.tqdm`
    **tqdm_kwargs:
        Passed to :code:`tqdm_class`. Unless given explicitly, ``total``,
        ``disable`` and ``initial`` default to ``algorithm.max_iteration``,
        ``not verbose`` and ``max(0, algorithm.iteration)`` respectively.
    '''
    def __init__(self, verbose=1, tqdm_class=tqdm_auto, **tqdm_kwargs):
        super().__init__(verbose=verbose)
        self.tqdm_class = tqdm_class
        self.tqdm_kwargs = tqdm_kwargs
        # length of algorithm.objective when the postfix was last refreshed
        self._obj_len = 0

    def __call__(self, algorithm):
        if not hasattr(self, 'pbar'):
            defaults = {
                'total': algorithm.max_iteration,
                'disable': not self.verbose,
                'initial': max(0, algorithm.iteration),
            }
            # user-supplied kwargs take precedence over these defaults
            for key, value in defaults.items():
                self.tqdm_kwargs.setdefault(key, value)
            self.pbar = self.tqdm_class(**self.tqdm_kwargs)

        n_objectives = len(algorithm.objective)
        if n_objectives != self._obj_len:
            # only touch the postfix when a new objective value was recorded
            postfix = algorithm.objective_to_dict(self.verbose >= 2)
            self.pbar.set_postfix(postfix, refresh=False)
            self._obj_len = n_objectives

        self.pbar.update(algorithm.iteration - self.pbar.n)
class _TqdmText(tqdm_std):
    ''':code:`tqdm`-based progress, but each update is written as a separate line of text.

    Parameters
    ----------
    num_format: str
        Format spec for postfix numbers (i.e. objective values).
    bar_format: str
        Passed to :code:`tqdm`.
    '''
    def __init__(self, *args, num_format='+8.3e',
                 bar_format="{n:>6d}/{total_fmt:<6}{rate_fmt:>9}{postfix}",
                 **kwargs):
        # set before super().__init__() in case the parent constructor
        # already formats numbers (NOTE(review): confirm against tqdm)
        self.num_format = num_format
        super().__init__(*args, bar_format=bar_format, mininterval=0,
                         maxinterval=0, position=0, **kwargs)
        # don't interfere with external progress bars
        self._instances.remove(self)

    @staticmethod
    def status_printer(file):
        '''Return a writer that emits each status string on its own line, then flushes.'''
        flush = getattr(file, 'flush', lambda: None)

        def write_line(s):
            file.write(f"{s}\n")
            flush()

        return write_line

    def format_num(self, n):
        '''Format ``n`` with the :code:`num_format` spec given at construction.'''
        return f'{n:{self.num_format}}'

    def display(self, *args, **kwargs):
        """
        Clears :code:`postfix` if :code:`super().display()` succeeds
        (if display updates are more frequent than objective updates,
        users should not think the objective has stabilised).
        """
        updated = super().display(*args, **kwargs)
        if updated:
            self.set_postfix_str('', refresh=False)
        return updated
class TextProgressCallback(ProgressCallback):
    ''':code:`ProgressCallback` but printed on separate lines to screen.

    Parameters
    ----------
    miniters: int, default :code:`Algorithm.update_objective_interval`
        Number of algorithm iterations between screen prints. Values larger
        than :code:`Algorithm.update_objective_interval` are capped to it.
    '''
    __init__ = partialmethod(ProgressCallback.__init__, tqdm_class=_TqdmText)

    def __call__(self, algorithm):
        if not hasattr(self, 'pbar'):
            interval = algorithm.update_objective_interval
            requested = self.tqdm_kwargs.get('miniters', interval)
            # never print less often than the objective is updated
            self.tqdm_kwargs['miniters'] = min(requested, interval)
        return super().__call__(algorithm)
class LogfileCallback(TextProgressCallback):
    ''':code:`TextProgressCallback` but to a file instead of screen.

    The file is opened on construction and stays open until :code:`close()`
    is called (or the callback is used as a context manager and its block exits).

    Parameters
    ----------
    log_file: FileDescriptorOrPath
        Passed to :code:`open()`.
    mode: str
        Passed to :code:`open()`.
    **kwargs:
        Passed to :code:`TextProgressCallback`.
    '''
    def __init__(self, log_file, mode='a', **kwargs):
        self.fd = open(log_file, mode=mode)
        try:
            super().__init__(file=self.fd, **kwargs)
        except BaseException:
            # don't leak the handle if the parent constructor fails
            self.fd.close()
            raise

    def close(self):
        '''Close the progress bar (if one was created) and then the log file.'''
        pbar = getattr(self, 'pbar', None)
        if pbar is not None:
            # close the bar first so its final state is still written to the open file
            pbar.close()
        if not self.fd.closed:
            self.fd.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
class EarlyStoppingObjectiveValue(Callback):
    '''Callback that stops iterations if the change in the objective value is less than a provided threshold value.

    Parameters
    ----------
    threshold: float, default 1e-6
    verbose: int, default 1
        Passed to :code:`Callback`.

    Note
    -----
    This callback only compares the last two calculated objective values.
    If `update_objective_interval` is greater than 1, the objective value is
    not calculated at each iteration but only every `update_objective_interval`
    iterations, so the comparison spans that many iterations.
    '''
    def __init__(self, threshold=1e-6, verbose=1):
        # initialise the base class so that `self.verbose` exists like on every other Callback
        super().__init__(verbose=verbose)
        self.threshold = threshold

    def __call__(self, algorithm):
        if len(algorithm.loss) >= 2:
            if np.abs(algorithm.loss[-1] - algorithm.loss[-2]) < self.threshold:
                raise StopIteration


class CGLSEarlyStopping(Callback):
    r'''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` is set to default as '1e-6', :math:`x` is the current iterate and :math:`x_0` is the initial value.
    It will also terminate if the algorithm begins to diverge i.e. if :math:`||x||_2> \omega`, where `omega` is set to default as 1e6.

    Parameters
    ----------
    epsilon: float, default 1e-6
        Usually a small number: the algorithm will terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2`
    omega: float, default 1e6
        Usually a large number: the algorithm will terminate if :math:`||x||_2> \omega`
    verbose: int, default 1
        Passed to :code:`Callback`; the termination messages are printed only when non-zero.

    Note
    -----
    This callback is implemented to replicate the automatic behaviour of CGLS in CIL versions <=24. It also replicates the behaviour of https://web.stanford.edu/group/SOL/software/cgls/.
    '''
    def __init__(self, epsilon=1e-6, omega=1e6, verbose=1):
        # initialise the base class so that `self.verbose` exists like on every other Callback
        super().__init__(verbose=verbose)
        self.epsilon = epsilon
        self.omega = omega

    def __call__(self, algorithm):
        if algorithm.norms <= algorithm.norms0 * self.epsilon:
            if self.verbose:
                print('The norm of the residual is less than {} times the norm of the initial residual and so the algorithm is terminated'.format(self.epsilon))
            raise StopIteration
        # Compare the norm computed here; previously it was computed and then
        # ignored in favour of `algorithm.normx`, which not every algorithm provides.
        self.normx = algorithm.x.norm()
        if self.normx >= self.omega:
            if self.verbose:
                print('The norm of the solution is greater than {} and so the algorithm is terminated'.format(self.omega))
            raise StopIteration