from abc import ABC, abstractmethod
from functools import partialmethod
from import tqdm as tqdm_auto
from tqdm.std import tqdm as tqdm_std
import numpy as np
class Callback(ABC):
    '''Base Callback to inherit from for use in :code:`Algorithm.run(callbacks: list[Callback])`.

    Subclasses must implement :code:`__call__(self, algorithm)`, which receives the
    running algorithm instance.

    Parameters
    ----------
    verbose: int, choice of 0,1,2, default 1
        0=quiet, 1=info, 2=debug.
    '''
    def __init__(self, verbose=1):
        self.verbose = verbose

    @abstractmethod
    def __call__(self, algorithm):
        '''Hook invoked with the running :code:`algorithm`; must be overridden by subclasses.'''
class _OldCallback(Callback):
    '''Converts an old-style :code:`def callback` to a new-style :code:`class Callback`.

    Parameters
    ----------
    callback: :code:`callable(iteration, objective, x)`
        Called only when :code:`algorithm.update_objective_interval > 0` and the
        current iteration is a multiple of that interval.
    *args, **kwargs:
        Passed to :code:`Callback.__init__` (e.g. :code:`verbose`).
    '''
    def __init__(self, callback, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.func = callback

    def __call__(self, algorithm):
        interval = algorithm.update_objective_interval
        if interval > 0 and algorithm.iteration % interval == 0:
            # `return_all` is requested only at debug verbosity (verbose >= 2).
            objective = algorithm.get_last_objective(return_all=self.verbose >= 2)
            self.func(algorithm.iteration, objective, algorithm.x)
class ProgressCallback(Callback):
    ''':code:`tqdm`-based progress bar.

    Parameters
    ----------
    verbose: int, choice of 0,1,2, default 1
        0 disables the bar (unless :code:`disable` is passed explicitly);
        :code:`verbose >= 2` is forwarded to :code:`algorithm.objective_to_dict`.
    tqdm_class: default :code:`tqdm.auto.tqdm`
        Progress-bar class, instantiated lazily on the first call.
    **tqdm_kwargs:
        Passed to :code:`tqdm_class`.
    '''
    def __init__(self, verbose=1, tqdm_class=tqdm_auto, **tqdm_kwargs):
        # __call__ reads self.verbose, so the base class must initialise it.
        super().__init__(verbose=verbose)
        self.tqdm_class = tqdm_class
        self.tqdm_kwargs = tqdm_kwargs
        self._obj_len = 0  # number of objective updates already shown

    def __call__(self, algorithm):
        if not hasattr(self, 'pbar'):
            # Create the bar lazily: `total` and `initial` depend on the algorithm,
            # which is only available once this callback is first invoked.
            tqdm_kwargs = self.tqdm_kwargs
            tqdm_kwargs.setdefault('total', algorithm.max_iteration)
            tqdm_kwargs.setdefault('disable', not self.verbose)
            tqdm_kwargs.setdefault('initial', max(0, algorithm.iteration))
            self.pbar = self.tqdm_class(**tqdm_kwargs)
        # Refresh the postfix only when a new objective value has been recorded.
        if (obj_len := len(algorithm.objective)) != self._obj_len:
            self.pbar.set_postfix(algorithm.objective_to_dict(self.verbose >= 2), refresh=False)
            self._obj_len = obj_len
        # Advance the bar to the algorithm's current iteration.
        self.pbar.update(algorithm.iteration - self.pbar.n)
class _TqdmText(tqdm_std):
    ''':code:`tqdm`-based progress but text-only updates on separate lines.

    Parameters
    ----------
    num_format: str
        Format spec for postfix numbers (i.e. objective values).
    bar_format: str
        Passed to :code:`tqdm`.
    '''
    def __init__(self, *args, num_format='+8.3e', bar_format="{n:>6d}/{total_fmt:<6} {rate_fmt:>9}{postfix}", **kwargs):
        self.num_format = num_format
        # mininterval/maxinterval=0 disable time-based throttling, so print
        # frequency is governed by `miniters` alone.
        super().__init__(*args, bar_format=bar_format, mininterval=0, maxinterval=0, position=0, **kwargs)
        self._instances.remove(self)  # don't interfere with external progress bars

    @staticmethod
    def status_printer(file):
        '''Returns a writer that emits each status string on its own line.

        Must be a staticmethod: :code:`tqdm` calls it as :code:`self.status_printer(self.fp)`.
        '''
        fp_flush = getattr(file, 'flush', lambda: None)

        def fp_write(s):
            # No carriage-return overwrite: every update is a new line.
            file.write(f"{s}\n")
            fp_flush()

        return fp_write

    def format_num(self, n):
        '''Formats postfix numbers using :code:`num_format`.'''
        return f'{n:{self.num_format}}'

    def display(self, *args, **kwargs):
        '''
        Clears :code:`postfix` if :code:`super().display()` succeeds
        (if display updates are more frequent than objective updates, users should not think the objective has stabilised).
        '''
        if (updated := super().display(*args, **kwargs)):
            self.set_postfix_str('', refresh=False)
        return updated
class TextProgressCallback(ProgressCallback):
    ''':code:`ProgressCallback` but printed on separate lines to screen.

    Parameters
    ----------
    miniters: int, default :code:`Algorithm.update_objective_interval`
        Number of algorithm iterations between screen prints.
        Values larger than :code:`Algorithm.update_objective_interval` are clamped to it.
    '''
    # Same constructor as ProgressCallback, defaulting to the line-based bar class.
    __init__ = partialmethod(ProgressCallback.__init__, tqdm_class=_TqdmText)

    def __call__(self, algorithm):
        if not hasattr(self, 'pbar'):
            # Presumably so that every new objective value gets printed:
            # never print less often than the objective is updated.
            interval = algorithm.update_objective_interval
            self.tqdm_kwargs['miniters'] = min(
                self.tqdm_kwargs.get('miniters', interval),
                interval)
        return super().__call__(algorithm)
class LogfileCallback(TextProgressCallback):
    ''':code:`TextProgressCallback` but to a file instead of screen.

    Parameters
    ----------
    log_file: FileDescriptorOrPath
        Passed to :code:`open()`.
    mode: str
        Passed to :code:`open()`.
    **kwargs:
        Passed to :code:`TextProgressCallback`; must not contain :code:`file`.
    '''
    def __init__(self, log_file, mode='a', **kwargs):
        # The handle stays open after successful construction; it is written to by the bar.
        self.fd = open(log_file, mode=mode)
        try:
            super().__init__(file=self.fd, **kwargs)
        except BaseException:
            # Don't leak the handle if construction fails (e.g. `file` passed twice).
            self.fd.close()
            raise
class EarlyStoppingObjectiveValue(Callback):
    '''Callback that stops iterations if the change in the objective value is less than a provided threshold value.

    Parameters
    ----------
    threshold: float, default 1e-6
        Iterations stop (via :code:`StopIteration`) once the absolute difference
        between the last two recorded objective values is below this value.

    Note
    -----
    This callback only compares the last two calculated objective values. Objective values are only
    calculated every `update_objective_interval` iterations, so if `update_objective_interval` is greater
    than 1 the two values compared are that many iterations apart.
    '''
    def __init__(self, threshold=1e-6):
        super().__init__()
        self.threshold = threshold

    def __call__(self, algorithm):
        losses = algorithm.loss
        # Need at least two recorded values before a change can be measured.
        if len(losses) >= 2 and np.abs(losses[-1] - losses[-2]) < self.threshold:
            raise StopIteration
class CGLSEarlyStopping(Callback):
    r'''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` defaults to 1e-6, :math:`x` is the current iterate and :math:`x_0` is the initial value.
    It will also terminate if the algorithm begins to diverge i.e. if :math:`||x||_2> \omega`, where `omega` defaults to 1e6.

    Parameters
    ----------
    epsilon: float, default 1e-6
        Usually a small number: the algorithm terminates if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2`
    omega: float, default 1e6
        Usually a large number: the algorithm will terminate if :math:`||x||_2> \omega`

    Note
    -----
    This callback is implemented to replicate the automatic behaviour of CGLS in CIL versions <=24. It also replicates the behaviour of https://web.stanford.edu/group/SOL/software/cgls/.
    '''
    def __init__(self, epsilon=1e-6, omega=1e6):
        super().__init__()
        self.epsilon = epsilon
        self.omega = omega

    def __call__(self, algorithm):
        # Converged: the residual norm has dropped to epsilon times its initial value.
        if algorithm.norms <= algorithm.norms0 * self.epsilon:
            print('The norm of the residual is less than {} times the norm of the initial residual and so the algorithm is terminated'.format(self.epsilon))
            raise StopIteration
        # Diverging: test the norm of the current iterate computed here, instead of
        # a separate `algorithm.normx` attribute that may not exist or be stale.
        self.normx = algorithm.x.norm()
        if self.normx >= self.omega:
            print('The norm of the solution is greater than {} and so the algorithm is terminated'.format(self.omega))
            raise StopIteration