Source code for cil.optimisation.functions.LeastSquares

#  Copyright 2018 United Kingdom Research and Innovation
#  Copyright 2018 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

from cil.optimisation.operators import LinearOperator, DiagonalOperator
from cil.optimisation.functions import Function
from cil.framework import DataContainer
import warnings
from numbers import Number
import numpy as np


class LeastSquares(Function):
    r""" (Weighted) Least Squares function

    .. math:: F(x) = c\|Ax-b\|_2^2

    or if weighted

    .. math:: F(x) = c\|Ax-b\|_{2,W}^{2}

    where :math:`W=\text{diag}(weight)`.

    Parameters
    -----------
    A : LinearOperator
    b : Data, DataContainer
    c : Scaling Constant, float, default 1.0
    weight: DataContainer with non-negative elements of size of the range of operator A, default None

    Raises
    ------
    ValueError
        If ``weight`` is given and contains negative values.

    Note
    --------
    L is the Lipschitz Constant of the gradient of :math:`F` which is
    :math:`2 c ||A||_2^2 = 2 c \sigma_1(A)^2`, or
    :math:`2 c ||W|| ||A||_2^2 = 2c||W|| \sigma_1(A)^2`, where
    :math:`\sigma_1(A)` is the largest singular value of :math:`A` and
    :math:`W=\text{diag}(weight)`.
    """

    def __init__(self, A, b, c=1.0, weight=None):
        super(LeastSquares, self).__init__()

        self.A = A  # Should be a LinearOperator
        self.b = b
        self.c = c  # Default 1.

        # Optional weights; the norm of diag(weight) is computed lazily
        # and cached by the ``weight_norm`` property.
        self.weight = weight
        self._weight_norm = None

        if weight is not None:
            if (self.weight < 0).any():
                raise ValueError('Weight contains negative values')

    def __call__(self, x):
        r""" Returns the value of :math:`F(x) = c\|Ax-b\|_2^2` or
        :math:`c\|Ax-b\|_{2,W}^2`, where :math:`W=\text{diag}(weight)`.
        """
        # Residual y = A x - b, computed in place on the result of A.direct
        # to avoid allocating a second container.
        y = self.A.direct(x)
        y.subtract(self.b, out=y)

        if self.weight is None:
            return self.c * y.dot(y)

        # Weighted case: c * <y, W y>
        wy = self.weight.multiply(y)
        return self.c * y.dot(wy)

    def gradient(self, x, out=None):
        r""" Returns the value of the gradient of :math:`F(x)`:

        .. math:: F'(x) = 2cA^T(Ax-b)

        or

        .. math:: F'(x) = 2cA^T(W(Ax-b))

        where :math:`W=\text{diag}(weight)`.

        Parameters
        ----------
        x : point at which the gradient is evaluated
        out : optional container that receives the result in place, default None

        Returns
        -------
        The gradient when ``out`` is None; otherwise nothing is returned and
        the result is written to ``out``.
        """
        should_return = out is None
        if should_return:
            # Allocate a result buffer shaped like ``x``.
            out = x * 0.0

        # tmp = A x - b, optionally scaled element-wise by the weights
        tmp = self.A.direct(x)
        tmp.subtract(self.b, out=tmp)
        if self.weight is not None:
            tmp.multiply(self.weight, out=tmp)

        # out = 2 c A^T tmp
        self.A.adjoint(tmp, out=out)
        out.multiply(self.c * 2.0, out=out)

        if should_return:
            return out

    @property
    def L(self):
        """Lipschitz constant of the gradient, computed on first access."""
        if self._L is None:
            self.calculate_Lipschitz()
        return self._L

    @L.setter
    def L(self, value):
        warnings.warn("You should set the Lipschitz constant with calculate_Lipschitz().")
        if isinstance(value, (Number,)) and value >= 0:
            self._L = value
        else:
            raise TypeError('The Lipschitz constant is a real positive number')

    def calculate_Lipschitz(self):
        """Compute the Lipschitz constant from the operator norm if possible.

        On success ``self._L`` is set to ``2 |c| ||A||^2`` (times the norm of
        ``diag(weight)`` when weights are used). When the operator norm cannot
        be obtained a warning is issued and ``self._L`` is left unchanged.
        """
        lipschitz = None
        try:
            lipschitz = 2.0 * np.abs(self.c) * (self.A.norm()**2)
        except AttributeError as ae:
            # Operator has no usable ``norm``; fall back to the power method
            # for linear operators.
            if self.A.is_linear():
                Anorm = LinearOperator.PowerMethod(self.A, 10)[0]
                lipschitz = 2.0 * np.abs(self.c) * (Anorm*Anorm)
            else:
                warnings.warn('{} could not calculate Lipschitz Constant. {}'.format(
                    self.__class__.__name__, ae))
        except NotImplementedError as noe:
            warnings.warn('{} could not calculate Lipschitz Constant. {}'.format(
                self.__class__.__name__, noe))

        if lipschitz is None:
            # Nothing could be computed: do not touch self._L (in particular,
            # never multiply a None or a stale value by the weight norm).
            return

        if self.weight is not None:
            lipschitz *= self.weight_norm
        self._L = lipschitz

    @property
    def weight_norm(self):
        """Norm of ``diag(weight)`` (cached), or 1.0 when no weight is set."""
        if self.weight is not None:
            if self._weight_norm is None:
                D = DiagonalOperator(self.weight)
                self._weight_norm = D.norm()
        else:
            self._weight_norm = 1.0
        return self._weight_norm

    def __rmul__(self, other):
        '''defines the right multiplication with a number

        Returns a new LeastSquares with the scaling constant multiplied by
        ``other``. For non-numeric ``other`` it returns ``NotImplemented`` so
        that Python raises the standard ``TypeError`` for unsupported operands.
        '''
        if not isinstance(other, Number):
            # NotImplemented is a sentinel to be returned, not an exception
            # to be raised.
            return NotImplemented
        constant = self.c * other
        return LeastSquares(A=self.A, b=self.b, c=constant, weight=self.weight)