Source code for cil.optimisation.algorithms.Algorithm
#  Copyright 2019 United Kingdom Research and Innovation
#  Copyright 2019 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from itertools import count
from numbers import Integral
from typing import List, Optional
from warnings import warn

import numpy as np

from cil.optimisation.utilities.callbacks import Callback, LogfileCallback, _OldCallback, ProgressCallback
class Algorithm:
    r"""Base class providing minimal infrastructure for iterative algorithms.

    An iterative algorithm is designed to solve an optimization problem by repeatedly refining a solution.
    In CIL, we use iterative algorithms to minimize an objective function, often referred to as a loss.
    The process begins with an initial guess, and with each iteration, the algorithm updates the current
    solution based on the results of previous iterations (previous iterates).

    Iterative algorithms typically continue until a stopping criterion is met, indicating that an optimal
    or sufficiently good solution has been found. In CIL, stopping criteria can be implemented using a
    callback function (`cil.optimisation.utilities.callbacks`).

    The user is required to implement the :code:`set_up`, :code:`__init__`, :code:`update` and
    :code:`update_objective` methods.

    The method :code:`run` is available to run :code:`n` iterations. The method accepts :code:`callbacks`:
    a list of callables, each of which receive the current Algorithm object (which in turn contains the
    iteration number and the actual objective value) and can be used to trigger print to screens and other
    user interactions. The :code:`run` method will stop when the stopping criterion is met or
    `StopIteration` is raised.

    Parameters
    ----------
    update_objective_interval: int, optional, default 1
        The objective (or loss) is calculated and saved every `update_objective_interval`.
        1 means every iteration, 2 every 2 iterations and so forth. 0 disables the evaluation.
        This is by default 1 and should be increased when evaluating the objective is
        computationally expensive.
    max_iteration: int, optional
        Deprecated: use :code:`Algorithm.run(iterations)` instead.
    log_file: str, optional
        Deprecated: use :code:`run(callbacks=[LogfileCallback(log_file)])` instead.
    """

    def __init__(self, update_objective_interval=1, max_iteration=None, log_file=None):
        # -1 marks "nothing done yet": the first `__next__` call only evaluates
        # the objective of the initial guess (when the interval is > 0).
        self.iteration = -1
        self.__max_iteration = 1
        if max_iteration is not None:
            warn("use `Algorithm.run(iterations)` instead of `Algorithm(max_iteration)`",
                 DeprecationWarning, stacklevel=2)
            self.__max_iteration = max_iteration
        self.__loss = []
        self.memopt = False
        self.configured = False
        self._iteration = []
        self.update_objective_interval = update_objective_interval
        # NOTE: `self.x` is not created here; `get_output()` expects subclasses to provide it.
        self.iter_string = 'Iter'
        if log_file is not None:
            warn("use `run(callbacks=[LogfileCallback(log_file)])` instead of `log_file`",
                 DeprecationWarning, stacklevel=2)
        # Always stored (possibly as None) so that `run()` can test it reliably.
        self.__log_file = log_file

    def set_up(self, *args, **kwargs):
        '''Set up the algorithm'''
        raise NotImplementedError

    def update(self):
        '''A single iteration of the algorithm'''
        raise NotImplementedError

    def should_stop(self):
        '''default stopping criterion: number of iterations

        The user can change this in concrete implementation of iterative algorithms.'''
        return self.iteration > self.max_iteration

    def __set_up_logger(self, *_, **__):
        """Do not use: this is being deprecated"""
        warn("use `run(callbacks=[LogfileCallback(log_file)])` instead", DeprecationWarning, stacklevel=2)

    def max_iteration_stop_criterion(self):
        """Do not use: this is being deprecated"""
        warn("use `should_stop()` instead of `max_iteration_stop_criterion()`", DeprecationWarning, stacklevel=2)
        return self.iteration > self.max_iteration

    def __iter__(self):
        '''Algorithm is an iterable'''
        return self

    def __next__(self):
        '''Algorithm is an iterable

        This method triggers :code:`update()` and :code:`update_objective()`

        Returns
        -------
        int
            The iteration counter after this step.

        Raises
        ------
        StopIteration
            If :code:`should_stop()` is True.
        ValueError
            If an update is requested while :code:`configured` is False.
        '''
        if self.should_stop():
            raise StopIteration
        if self.iteration == -1 and self.update_objective_interval > 0:
            # Very first call: only record the objective of the initial guess.
            self._iteration.append(self.iteration)
            self.update_objective()
            self.iteration += 1
            return self.iteration
        if not self.configured:
            raise ValueError('Algorithm not configured correctly. Please run set_up.')
        self.update()
        self.iteration += 1

        self._update_previous_solution()

        # The `> 0` test must come first: it guards the modulo against an interval of 0.
        if (self.iteration >= 0 and self.update_objective_interval > 0
                and self.iteration % self.update_objective_interval == 0):
            self._iteration.append(self.iteration)
            self.update_objective()
        return self.iteration

    def _update_previous_solution(self):
        r"""An optional but common function that can be implemented by child classes to update a
        stored previous solution with the current one.

        Best practice for memory efficiency would be to do this by the swapping of pointers:

        .. highlight:: python
        .. code-block:: python

            tmp = self.x_old
            self.x_old = self.x
            self.x = tmp

        """
        pass

    def get_output(self):
        r"""Returns the current solution.

        Returns
        -------
        DataContainer
            The current solution
        """
        return self.x

    def _provable_convergence_condition(self):
        r"""Checks if the algorithm set-up (e.g. chosen step-sizes or other parameters) meets a
        mathematical convergence criterion.

        Returns
        -------
        bool:
            Outcome of the convergence check
        """
        raise NotImplementedError(" Convergence criterion is not implemented for this algorithm. ")

    def is_provably_convergent(self):
        r"""Check if the algorithm is convergent based on the provable convergence criterion.

        Returns
        -------
        Boolean
            Outcome of the convergence check
        """
        return self._provable_convergence_condition()

    @property
    def solution(self):
        " Returns the current solution. "
        return self.get_output()

    def get_last_loss(self, return_all=False):
        r'''Returns the last stored value of the loss function. "Loss" is an alias for "objective value".

        If `update_objective_interval` is 1 it is the value of the objective at the current iteration.
        If update_objective_interval > 1 it is the last stored value.

        Parameters
        ----------
        return_all: Boolean, default is False
            If True, returns all the stored loss functions

        Returns
        -------
        Float
            Last stored value of the loss function (NaN if nothing has been stored yet).
            With `return_all=True` a list is returned; a scalar loss is padded with NaN
            to three entries.
        '''
        try:
            objective = self.__loss[-1]
        except IndexError:
            objective = np.nan
        if isinstance(objective, list):
            return objective if return_all else objective[0]
        return [objective, np.nan, np.nan] if return_all else objective

    get_last_objective = get_last_loss  # alias

    def update_objective(self):
        '''calculates the objective with the current solution'''
        raise NotImplementedError

    @property
    def iterations(self):
        '''returns the iterations at which the objective has been evaluated'''
        return self._iteration

    @property
    def loss(self):
        '''returns a list of the values of the objective (alias of loss) during the iteration

        The length of this list may be shorter than the number of iterations run when the
        `update_objective_interval` > 1
        '''
        return self.__loss

    objective = loss  # alias

    @property
    def max_iteration(self):
        '''gets the maximum number of iterations'''
        return self.__max_iteration

    @max_iteration.setter
    def max_iteration(self, value):
        '''sets the maximum number of iterations'''
        # Accepts any integer or +infinity ("run until stopped").
        assert isinstance(value, Integral) or np.isposinf(value)
        self.__max_iteration = value

    @property
    def update_objective_interval(self):
        '''gets the update_objective_interval'''
        return self.__update_objective_interval

    @update_objective_interval.setter
    def update_objective_interval(self, value):
        '''sets the update_objective_interval'''
        if not isinstance(value, Integral) or value < 0:
            raise ValueError('interval must be an integer >= 0')
        self.__update_objective_interval = value

    def run(self, iterations=None, callbacks: "Optional[List[Callback]]" = None, verbose=1, **kwargs):
        r"""run upto :code:`iterations` with callbacks/logging.

        For a demonstration of callbacks see
        https://github.com/TomographicImaging/CIL-Demos/blob/main/misc/callback_demonstration.ipynb

        Parameters
        -----------
        iterations: int, default is None
            Number of iterations to run. If not set the algorithm will run until
            :code:`should_stop()` is reached
        callbacks: list of callables, default is None
            List of callables which are passed the current Algorithm object each iteration.
            Defaults to :code:`[ProgressCallback(verbose)]`. The list passed in is not modified.
        verbose: 0=quiet, 1=info, 2=debug
            Passed to the default callback to determine the verbosity of the printed output.
        **kwargs:
            Deprecated options: :code:`callback` (an old-style callback, wrapped in
            :code:`_OldCallback`) and :code:`print_interval` (ignored, triggers a warning).
        """
        if 'print_interval' in kwargs:
            warn("use `TextProgressCallback(miniters)` instead of `run(print_interval)`",
                 DeprecationWarning, stacklevel=2)
        if callbacks is None:
            callbacks = [ProgressCallback(verbose=verbose)]
        else:
            # Work on a copy: legacy callbacks appended below must not leak into the caller's list.
            callbacks = list(callbacks)
        # transform old-style callbacks into new
        callback = kwargs.get('callback', None)
        if callback is not None:
            callbacks.append(_OldCallback(callback, verbose=verbose))
        # `self.__log_file` is name-mangled to `_Algorithm__log_file`, so it has to be looked up
        # under that name; `hasattr(self, '__log_file')` would never find it.
        log_file = getattr(self, '_Algorithm__log_file', None)
        if log_file is not None:
            callbacks.append(LogfileCallback(log_file, verbose=verbose))

        if self.should_stop():
            print("Stop criterion has been reached.")
        if iterations is None:
            warn("`run()` missing `iterations`", DeprecationWarning, stacklevel=2)
            iterations = self.max_iteration

        if self.iteration == -1 and self.update_objective_interval > 0:
            # the very first `__next__` call only evaluates the initial objective
            iterations += 1

        # call `__next__` upto `iterations` times or until `StopIteration` is raised
        self.max_iteration = self.iteration + iterations

        if np.isposinf(self.max_iteration):
            iters = count(self.iteration)
        else:
            iters = range(self.iteration, self.max_iteration)
        # `zip` pulls from `iters` first, so `__next__` is never called once `iters` is exhausted.
        for _ in zip(iters, self):
            try:
                for callback in callbacks:
                    callback(self)
            except StopIteration:
                # a callback requested early termination
                break

    def objective_to_dict(self, verbose=False):
        """Internal function to save and print objective functions

        Returns a dict with keys 'primal', 'dual' and 'primal_dual' when `verbose` is True and
        the last stored loss has three entries that are not all-NaN beyond the first;
        otherwise a dict with the single key 'objective'.
        """
        obj = self.get_last_objective(return_all=verbose)
        if isinstance(obj, list) and len(obj) == 3:
            if not np.isnan(obj[1:]).all():
                return {'primal': obj[0], 'dual': obj[1], 'primal_dual': obj[2]}
            obj = obj[0]
        return {'objective': obj}

    def objective_to_string(self, verbose=False):
        """Do not use: this is being deprecated"""
        warn("consider using `run(callbacks=[LogfileCallback(log_file)])` instead", DeprecationWarning, stacklevel=2)
        return str(self.objective_to_dict(verbose=verbose))

    def verbose_output(self, *_, **__):
        """Do not use: this is being deprecated"""
        warn("use `run(callbacks=[ProgressCallback()])` instead", DeprecationWarning, stacklevel=2)

    def verbose_header(self, *_, **__):
        """Do not use: this is being deprecated"""
        warn("consider using `run(callbacks=[LogfileCallback(log_file)])` instead", DeprecationWarning, stacklevel=2)