# Source code for cil.optimisation.algorithms.CGLS
# Copyright 2019 United Kingdom Research and Innovation
# Copyright 2019 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from cil.optimisation.algorithms import Algorithm
import numpy
import logging
import warnings
log = logging.getLogger(__name__)
class CGLS(Algorithm):
    r'''Conjugate Gradient Least Squares (CGLS) algorithm

    The Conjugate Gradient Least Squares (CGLS) algorithm is commonly used for
    solving large systems of linear equations, due to its fast convergence.

    Problem:

    .. math::

        \min_x || A x - b ||^2_2

    Parameters
    ------------
    operator : Operator
        Linear operator for the inverse problem
    initial : (optional) DataContainer in the domain of the operator, default is a DataContainer filled with zeros.
        Initial guess
    data : DataContainer in the range of the operator
        Acquired data to reconstruct

    Note
    -----
    Passing tolerance directly to CGLS is being deprecated. Instead we recommend
    using the callback functionality:
    https://tomographicimaging.github.io/CIL/nightly/optimisation/#callbacks
    and in particular the CGLSEarlyStopping callback, which replicates the old
    behaviour.

    Reference
    ---------
    https://web.stanford.edu/group/SOL/software/cgls/
    '''

    def __init__(self, initial=None, operator=None, data=None, **kwargs):
        '''Initialise the algorithm.

        A ``tolerance`` keyword argument is still accepted for backward
        compatibility. It is removed from ``kwargs`` before the remaining
        keyword arguments are forwarded to the base class, and a
        ``DeprecationWarning`` is emitted. When it is not given the tolerance
        is 0.

        If ``initial`` is omitted but ``operator`` is given, the initial guess
        is a zero-filled container in the operator's domain. ``set_up`` is
        called only when ``initial``, ``operator`` and ``data`` are all
        available.
        '''
        # We are deprecating tolerance
        self.tolerance = kwargs.pop("tolerance", None)
        if self.tolerance is not None:
            warnings.warn(stacklevel=2, category=DeprecationWarning, message="Passing tolerance directly to CGLS is being deprecated. Instead we recommend using the callback functionality: https://tomographicimaging.github.io/CIL/nightly/optimisation/#callbacks and in particular the CGLSEarlyStopping callback replicated the old behaviour")
        else:
            self.tolerance = 0
        super().__init__(**kwargs)
        if initial is None and operator is not None:
            initial = operator.domain_geometry().allocate(0)
        if initial is not None and operator is not None and data is not None:
            self.set_up(initial=initial, operator=operator, data=data)

    def set_up(self, initial, operator, data):
        r'''Initialisation of the algorithm

        Parameters
        ------------
        operator : Operator
            Linear operator for the inverse problem
        initial : (optional) DataContainer in the domain of the operator, default is a DataContainer filled with zeros.
            Initial guess. If ``None``, a zero-filled container in the
            operator's domain is used.
        data : DataContainer in the range of the operator
            Acquired data to reconstruct
        '''
        log.info("%s setting up", self.__class__.__name__)
        if initial is None:
            # Honour the documented default instead of failing on None.copy().
            initial = operator.domain_geometry().allocate(0)
        self.x = initial.copy()
        self.operator = operator
        # Residual r = b - A x and normal-equation residual s = A^T r.
        self.r = data - self.operator.direct(self.x)
        self.s = self.operator.adjoint(self.r)
        # The first search direction is s itself.
        self.p = self.s.copy()
        # Work buffer for A p; update() writes into it in place every iteration.
        self.q = self.operator.range_geometry().allocate()
        # Compute ||s|| once; norms0 is kept for the deprecated tolerance test.
        self.norms0 = self.s.norm()
        self.norms = self.norms0
        self.gamma = self.norms0**2
        self.normx = self.x.norm()
        self.configured = True
        log.info("%s configured", self.__class__.__name__)

    def update(self):
        '''Perform a single CGLS iteration, updating x, r, s and p in place.'''
        # q = A p
        self.operator.direct(self.p, out=self.q)
        delta = self.q.squared_norm()
        # Step length along p: alpha = ||s||^2 / ||A p||^2
        alpha = self.gamma / delta
        # x <- x + alpha * p
        self.x.sapyb(1, self.p, alpha, out=self.x)
        # r <- r - alpha * q
        self.r.sapyb(1, self.q, -alpha, out=self.r)
        # s <- A^T r
        self.operator.adjoint(self.r, out=self.s)
        self.norms = self.s.norm()
        self.gamma1 = self.gamma
        self.gamma = self.norms**2
        self.beta = self.gamma / self.gamma1
        # p <- s + beta * p
        self.p.sapyb(self.beta, self.s, 1, out=self.p)
        # TODO: Deprecated, remove when CGLS tolerance is removed
        self.normx = self.x.norm()

    def update_objective(self):
        '''Append the squared residual norm ||r||^2 = ||b - A x||^2 to the loss.

        Raises
        ------
        StopIteration
            If the squared residual norm is NaN.
        '''
        a = self.r.squared_norm()
        # ``a is numpy.nan`` only matches that exact object, so a NaN produced
        # by the computation would slip through; test the value instead.
        if numpy.isnan(a):
            raise StopIteration()
        self.loss.append(a)

    def should_stop(self):  # TODO: Deprecated, remove when CGLS tolerance is removed
        '''Stop when the deprecated tolerance test fires or the base class says so.'''
        return self.flag() or super().should_stop()

    def flag(self):  # TODO: Deprecated, remove when CGLS tolerance is removed
        '''returns whether the tolerance has been reached'''
        # Stopping rule from the SOL cgls reference:
        #   ||s|| <= ||s_0|| * tol  or  ||x|| * tol >= 1
        # With the default tolerance of 0 this only fires when ||s|| == 0.
        flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1)
        if flag:
            self.update_objective()
            print('Tolerance is reached: {}'.format(self.tolerance))
        return flag