Source code for cil.optimisation.functions.Function
# Copyright 2019 United Kingdom Research and Innovation
# Copyright 2019 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

import warnings
from numbers import Number
import numpy as np
from functools import reduce
from cil.utilities.errors import InPlaceError

class Function(object):
    r"""Abstract class representing a function

    Parameters
    ----------
    L : number, positive, default None
        Lipschitz constant of the gradient of the function F(x), when it is differentiable.

    Note
    ----
    The Lipschitz constant of the gradient of the function is a positive real number, such that :math:`\|f'(x) - f'(y)\| \leq L \|x-y\|`, assuming :math:`f: IG \rightarrow \mathbb{R}`.
    """

    def __init__(self, L=None):
        # overrides the type check to allow None as initial value
        self._L = L

    def __call__(self, x):
        raise NotImplementedError
    def gradient(self, x, out=None):
        r"""Returns the value of the gradient of function :math:`F` evaluated at :math:`x`, if it is differentiable

        .. math:: F'(x)

        Parameters
        ----------
        x : DataContainer
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the value of the gradient of the function at x.
        """
        raise NotImplementedError
    def proximal(self, x, tau, out=None):
        r"""Returns the proximal operator of function :math:`\tau F` evaluated at x

        .. math:: \text{prox}_{\tau F}(x) = \underset{z}{\text{argmin}} \frac{1}{2}\|z - x\|^{2} + \tau F(z)

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the proximal operator of the function at x with scalar :math:`\tau`.
        """
        raise NotImplementedError
    def convex_conjugate(self, x):
        r"""Evaluation of the function F* at x, where F* is the convex conjugate of function F,

        .. math:: F^{*}(x^{*}) = \underset{x}{\sup} \langle x^{*}, x \rangle - F(x)

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The value of the convex conjugate of the function at x.
        """
        raise NotImplementedError
    def proximal_conjugate(self, x, tau, out=None):
        r"""Returns the proximal operator of the convex conjugate of function :math:`\tau F` evaluated at :math:`x^{*}`

        .. math:: \text{prox}_{\tau F^{*}}(x^{*}) = \underset{z^{*}}{\text{argmin}} \frac{1}{2}\|z^{*} - x^{*}\|^{2} + \tau F^{*}(z^{*})

        Due to Moreau's identity, we have an analytic formula to compute the proximal operator of the convex conjugate :math:`F^{*}`

        .. math:: \text{prox}_{\tau F^{*}}(x) = x - \tau\,\text{prox}_{\tau^{-1} F}(\tau^{-1}x)

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the value of the proximal operator of the convex conjugate at point :math:`x` for scalar :math:`\tau` (written to `out` if provided).
        """
        if id(x) == id(out):
            raise InPlaceError(message="The proximal_conjugate of a CIL function cannot be used in place")
        try:
            # divide x by tau in place when supported, and undo it below
            tmp = x
            x.divide(tau, out=tmp)
        except TypeError:
            tmp = x.divide(tau, dtype=np.float32)
        val = self.proximal(tmp, 1.0 / tau, out=out)
        if id(tmp) == id(x):
            x.multiply(tau, out=x)
        val.sapyb(-tau, x, 1.0, out=val)  # val = x - tau * val (Moreau's identity)
        return val
    # Algebra for the Function class:
    # - add functions
    # - subtract functions
    # - add/subtract a scalar
    # - multiply by a scalar

    def __add__(self, other):
        """Returns the sum of the functions.

        Cases:
        a) the sum of two functions :math:`(F_{1}+F_{2})(x) = F_{1}(x) + F_{2}(x)`
        b) the sum of a function with a scalar :math:`(F_{1}+scalar)(x) = F_{1}(x) + scalar`
        """
        if isinstance(other, Number):
            return SumScalarFunction(self, other)
        return SumFunction(self, other)

    def __radd__(self, other):
        """Makes addition commutative."""
        return self + other

    def __sub__(self, other):
        """Returns the difference of the functions."""
        return self + (-1) * other

    def __rmul__(self, scalar):
        """Returns the function multiplied by a scalar."""
        return ScaledFunction(self, scalar)

    def __mul__(self, scalar):
        return self.__rmul__(scalar)

    def __neg__(self):
        """Returns the negative of the function."""
        return -1 * self
    def centered_at(self, center):
        """Returns a translated function: if :math:`F(x)` is centered at the origin, the returned TranslateFunction :math:`F(x - b)` is centered at point b.

        Parameters
        ----------
        center : DataContainer
            The point to center the function at.

        Returns
        -------
        The translated function.
        """
        if center is None:
            return self
        return TranslateFunction(self, center)
    @property
    def L(self):
        r"""Lipschitz constant of the gradient of function f.

        L is a positive real number such that :math:`\|f'(x) - f'(y)\| \leq L\|x-y\|`, assuming :math:`f: IG \rightarrow \mathbb{R}`.
        """
        return self._L

    @L.setter
    def L(self, value):
        """Setter for the Lipschitz constant."""
        if isinstance(value, Number) and value >= 0:
            self._L = value
        else:
            raise TypeError('The Lipschitz constant must be a non-negative real number')
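
# ---------------------------------------------------------------------------
# Editor's usage sketch, not part of the original source: a minimal check of
# the Function algebra and of Moreau's identity used by proximal_conjugate.
# It assumes the CIL API shown in the docstrings below (ImageGeometry,
# L2NormSquared, ig.allocate); the helper name _demo_function_algebra is
# hypothetical.
def _demo_function_algebra():
    from cil.framework import ImageGeometry
    from cil.optimisation.functions import L2NormSquared

    ig = ImageGeometry(voxel_num_x=4, voxel_num_y=4)
    x = ig.allocate('random')

    f = L2NormSquared()            # F(x) = ||x||^2, with f.L == 2
    g = 2 * f + 1                  # builds ScaledFunction and SumScalarFunction
    assert abs(g(x) - (2 * f(x) + 1)) < 1e-5

    # Moreau's identity: prox_{tau F*}(x) == x - tau * prox_{F/tau}(x/tau)
    tau = 0.5
    lhs = f.proximal_conjugate(x, tau)
    rhs = x - f.proximal(x / tau, 1.0 / tau) * tau
    assert (lhs - rhs).norm() < 1e-5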

class SumFunction(Function):
    r"""SumFunction represents the sum of :math:`n\geq2` functions

    .. math:: (F_{1} + F_{2} + ... + F_{n})(\cdot) = F_{1}(\cdot) + F_{2}(\cdot) + ... + F_{n}(\cdot)

    Parameters
    ----------
    *functions : Functions
        Functions to set up a :class:`.SumFunction`

    Raises
    ------
    IndexError
        If no function is passed.

    Examples
    --------
    .. math:: F(x) = \|x\|^{2} + \frac{1}{2}\|x - 1\|^{2}

    >>> from cil.optimisation.functions import L2NormSquared
    >>> from cil.framework import ImageGeometry
    >>> ig = ImageGeometry(voxel_num_x=3, voxel_num_y=3)
    >>> f1 = L2NormSquared()
    >>> f2 = 0.5 * L2NormSquared(b=ig.allocate(1))
    >>> F = SumFunction(f1, f2)

    .. math:: F(x) = \sum_{i=0}^{49} \|x - i\|^{2}

    >>> F = SumFunction(*[L2NormSquared(b=i) for i in range(50)])
    """

    def __init__(self, *functions):
        super(SumFunction, self).__init__()
        if not len(functions):
            raise IndexError('At least 1 function needed')
        self.functions = functions

    @property
    def L(self):
        r"""Returns the Lipschitz constant for the SumFunction

        .. math:: L = \sum_{i} L_{i}

        where :math:`L_{i}` is the Lipschitz constant of the smooth function :math:`F_{i}`.
        """
        L = 0.
        for f in self.functions:
            if f.L is not None:
                L += f.L
            else:
                L = None
                break
        self._L = L
        return self._L

    @L.setter
    def L(self, value):
        # call base class setter
        super(SumFunction, self.__class__).L.fset(self, value)

    @property
    def Lmax(self):
        r"""Returns the maximum Lipschitz constant for the SumFunction

        .. math:: L = \max_{i}\{L_{i}\}

        where :math:`L_{i}` is the Lipschitz constant of the smooth function :math:`F_{i}`.
        """
        l = []
        for f in self.functions:
            if f.L is not None:
                l.append(f.L)
            else:
                l = None
                break
        # if any summand has no Lipschitz constant, Lmax is undefined
        self._Lmax = max(l) if l is not None else None
        return self._Lmax

    @Lmax.setter
    def Lmax(self, value):
        # call base class setter
        super(SumFunction, self.__class__).Lmax.fset(self, value)

    def __call__(self, x):
        r"""Returns the value of the sum of functions evaluated at :math:`x`.

        .. math:: (F_{1} + F_{2} + ... + F_{n})(x) = F_{1}(x) + F_{2}(x) + ... + F_{n}(x)
        """
        ret = 0.
        for f in self.functions:
            ret += f(x)
        return ret
    def gradient(self, x, out=None):
        r"""Returns the value of the sum of the gradients of the functions evaluated at :math:`x`, if all of them are differentiable.

        .. math:: (F'_{1} + F'_{2} + ... + F'_{n})(x) = F'_{1}(x) + F'_{2}(x) + ... + F'_{n}(x)

        Parameters
        ----------
        x : DataContainer
            Point to evaluate the gradient at.
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the value of the sum of the gradients evaluated at point :math:`x`.
        """
        if out is not None and id(x) == id(out):
            raise InPlaceError
        for i, f in enumerate(self.functions):
            if i == 0:
                ret = f.gradient(x, out=out)
            else:
                ret += f.gradient(x)
        return ret
    def __add__(self, other):
        """Addition for the SumFunction.

        * :code:`SumFunction` + :code:`SumFunction` is a :code:`SumFunction`.
        * :code:`SumFunction` + :code:`Function` is a :code:`SumFunction`.
        """
        if isinstance(other, SumFunction):
            functions = list(self.functions) + list(other.functions)
            return SumFunction(*functions)
        elif isinstance(other, Function):
            functions = list(self.functions)
            functions.append(other)
            return SumFunction(*functions)
        else:
            return super(SumFunction, self).__add__(other)

    @property
    def num_functions(self):
        return len(self.functions)
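
# Editor's usage sketch, not part of the original source: exercising the
# SumFunction bookkeeping above, assuming L2NormSquared (whose gradient has
# Lipschitz constant 2) and the ImageGeometry API from the docstring example.
def _demo_sum_function():
    from cil.framework import ImageGeometry
    from cil.optimisation.functions import L2NormSquared

    ig = ImageGeometry(voxel_num_x=3, voxel_num_y=3)
    x = ig.allocate('random')

    F = SumFunction(L2NormSquared(), L2NormSquared(b=ig.allocate(1)))
    assert F.num_functions == 2
    assert F.L == 4.0      # sum of the individual constants, 2 + 2
    assert F.Lmax == 2.0   # the largest individual constant

    # SumFunction + Function flattens into a single SumFunction
    G = F + L2NormSquared()
    assert G.num_functions == 3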

class ScaledFunction(Function):
    r"""ScaledFunction represents the scalar multiplication with a Function.

    Let F be a function and :math:`\alpha` a scalar. If :math:`G(x) = \alpha F(x)` then:

    1. :math:`G(x) = \alpha F(x)` ( __call__ method )
    2. :math:`G'(x) = \alpha F'(x)` ( gradient method )
    3. :math:`G^{*}(x^{*}) = \alpha F^{*}(\frac{x^{*}}{\alpha})` ( convex_conjugate method )
    4. :math:`\text{prox}_{\tau G}(x) = \text{prox}_{(\tau\alpha) F}(x)` ( proximal method )
    """

    def __init__(self, function, scalar):
        super(ScaledFunction, self).__init__()
        if not isinstance(scalar, Number):
            raise TypeError('expected scalar: got {}'.format(type(scalar)))
        self.scalar = scalar
        self.function = function

    @property
    def L(self):
        if self._L is None:
            if self.function.L is not None:
                self._L = abs(self.scalar) * self.function.L
            else:
                self._L = None
        return self._L

    @L.setter
    def L(self, value):
        # call base class setter
        super(ScaledFunction, self.__class__).L.fset(self, value)

    @property
    def scalar(self):
        return self._scalar

    @scalar.setter
    def scalar(self, value):
        if isinstance(value, Number):
            self._scalar = value
        else:
            raise TypeError('Expecting scalar of a number type. Got {}'.format(type(value)))

    def __call__(self, x):
        r"""Returns the value of the scaled function evaluated at :math:`x`.

        .. math:: G(x) = \alpha F(x)

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The value of the scaled function evaluated at :math:`x`.
        """
        return self.scalar * self.function(x)
    def convex_conjugate(self, x):
        r"""Returns the convex conjugate of the scaled function.

        .. math:: G^{*}(x^{*}) = \alpha F^{*}(\frac{x^{*}}{\alpha})

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The value of the convex conjugate of the scaled function.
        """
        try:
            # divide x by the scalar in place when supported, and undo it below
            x.divide(self.scalar, out=x)
            tmp = x
        except TypeError:
            tmp = x.divide(self.scalar, dtype=np.float32)
        val = self.function.convex_conjugate(tmp)
        if id(tmp) == id(x):
            x.multiply(self.scalar, out=x)
        return self.scalar * val
    def gradient(self, x, out=None):
        r"""Returns the gradient of the scaled function evaluated at :math:`x`.

        .. math:: G'(x) = \alpha F'(x)

        Parameters
        ----------
        x : DataContainer
            Point to evaluate the gradient at.
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the value of the gradient of the scaled function evaluated at :math:`x`.
        """
        res = self.function.gradient(x, out=out)
        res *= self.scalar
        return res
    def proximal(self, x, tau, out=None):
        r"""Returns the proximal operator of the scaled function, evaluated at :math:`x`.

        .. math:: \text{prox}_{\tau G}(x) = \text{prox}_{(\tau\alpha) F}(x)

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the proximal operator of the scaled function evaluated at :math:`x` with scalar :math:`\tau`.
        """
        return self.function.proximal(x, tau * self.scalar, out=out)
    def proximal_conjugate(self, x, tau, out=None):
        r"""Returns the proximal operator of the convex conjugate of the scaled function, evaluated at :math:`x` with scalar :math:`\tau`.

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the proximal operator of the convex conjugate of the scaled function evaluated at :math:`x` and :math:`\tau`.
        """
        if out is not None and id(x) == id(out):
            raise InPlaceError
        try:
            # divide x by tau in place when supported, and undo it below
            tmp = x
            x.divide(tau, out=tmp)
        except TypeError:
            tmp = x.divide(tau, dtype=np.float32)
        val = self.function.proximal(tmp, self.scalar / tau, out=out)
        if id(tmp) == id(x):
            x.multiply(tau, out=x)
        val.sapyb(-tau, x, 1.0, out=val)  # val = x - tau * val (Moreau's identity)
        return val
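
# Editor's usage sketch, not part of the original source: numerically checking
# the ScaledFunction identities listed in the class docstring, assuming
# L2NormSquared and the DataContainer arithmetic used elsewhere in this file.
def _demo_scaled_function():
    from cil.framework import ImageGeometry
    from cil.optimisation.functions import L2NormSquared

    ig = ImageGeometry(voxel_num_x=3, voxel_num_y=3)
    x = ig.allocate('random')
    alpha, tau = 2.0, 0.5

    f = L2NormSquared()
    g = alpha * f                       # ScaledFunction, G = alpha F

    assert abs(g(x) - alpha * f(x)) < 1e-5                        # G(x) = alpha F(x)
    assert (g.gradient(x) - f.gradient(x) * alpha).norm() < 1e-5  # G' = alpha F'
    # prox_{tau G}(x) = prox_{(tau alpha) F}(x)
    assert (g.proximal(x, tau) - f.proximal(x, tau * alpha)).norm() < 1e-5
    assert g.L == alpha * f.L           # L scales with |alpha|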

class SumScalarFunction(SumFunction):
    r"""SumScalarFunction represents the sum of a function with a scalar.

    .. math:: (F + scalar)(x) = F(x) + scalar

    Although SumFunction has no general expressions for

    i) convex_conjugate
    ii) proximal
    iii) proximal_conjugate

    if the second argument is a ConstantFunction then we can derive the above analytically.
    """

    def __init__(self, function, constant):
        super(SumScalarFunction, self).__init__(function, ConstantFunction(constant))
        self.constant = constant
        self.function = function
    def convex_conjugate(self, x):
        r"""Returns the convex conjugate of :math:`(F+scalar)`, evaluated at :math:`x`.

        .. math:: (F+scalar)^{*}(x^{*}) = F^{*}(x^{*}) - scalar

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The value of the convex conjugate evaluated at :math:`x`.
        """
        return self.function.convex_conjugate(x) - self.constant
    def proximal(self, x, tau, out=None):
        r"""Returns the proximal operator of :math:`F+scalar`

        .. math:: \text{prox}_{\tau (F+scalar)}(x) = \text{prox}_{\tau F}(x)

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the proximal operator evaluated at :math:`x` and :math:`\tau`.
        """
        return self.function.proximal(x, tau, out=out)
    @property
    def L(self):
        if self._L is None:
            if self.function.L is not None:
                self._L = self.function.L
            else:
                self._L = None
        return self._L

    @L.setter
    def L(self, value):
        # call base class setter
        super(SumScalarFunction, self.__class__).L.fset(self, value)
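
# Editor's usage sketch, not part of the original source: Function.__add__
# with a Number builds a SumScalarFunction, whose conjugate and proximal use
# the analytic shortcuts above; assumes L2NormSquared as before.
def _demo_sum_scalar_function():
    from cil.framework import ImageGeometry
    from cil.optimisation.functions import L2NormSquared

    ig = ImageGeometry(voxel_num_x=3, voxel_num_y=3)
    x = ig.allocate('random')

    f = L2NormSquared()
    g = f + 3.0                                    # SumScalarFunction
    assert abs(g(x) - (f(x) + 3.0)) < 1e-5
    # (F + c)*(x*) = F*(x*) - c, and the constant leaves the proximal untouched
    assert abs(g.convex_conjugate(x) - (f.convex_conjugate(x) - 3.0)) < 1e-5
    assert (g.proximal(x, 0.5) - f.proximal(x, 0.5)).norm() < 1e-5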

class ConstantFunction(Function):
    r"""ConstantFunction: :math:`F(x) = constant, constant\in\mathbb{R}`"""

    def __init__(self, constant=0):
        self.constant = constant
        super(ConstantFunction, self).__init__(L=1)

    def __call__(self, x):
        """Returns the value of the function, :math:`F(x) = constant`"""
        return self.constant
    def gradient(self, x, out=None):
        """Returns the value of the gradient of the function, :math:`F'(x) = 0`

        Parameters
        ----------
        x : DataContainer
            Point to evaluate the gradient at.
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        A DataContainer of zeros, the same size as :math:`x`.
        """
        if out is None:
            return x * 0.
        out.fill(0)
        return out
    def convex_conjugate(self, x):
        r"""The convex conjugate of the constant function :math:`F(x) = c\in\mathbb{R}` is

        .. math::
            F^{*}(x^{*}) =
            \begin{cases}
            -c, & \mbox{if } x^{*} = 0\\
            \infty, & \mbox{otherwise}
            \end{cases}

        However, :math:`x^{*} = 0` is reached only in the limit of iterations, so in practice this value would be infinity. We do not want inf values in the convex conjugate, so we penalise this value instead. The following penalisation is useful in the PDHG algorithm, when we compute primal and dual objectives for convergence purposes.

        .. math:: F^{*}(x^{*}) = \sum \max\{x^{*}, 0\}

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The maximum of x and 0, summed over the entries of x.
        """
        return x.maximum(0).sum()
    def proximal(self, x, tau, out=None):
        r"""Returns the proximal operator of the constant function, which maps :math:`x` to itself, i.e.

        .. math:: \text{prox}_{\tau F}(x) = x

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, equal to :math:`x`.
        """
        if out is None:
            return x.copy()
        out.fill(x)
        return out
    @property
    def constant(self):
        return self._constant

    @constant.setter
    def constant(self, value):
        if not isinstance(value, Number):
            raise TypeError('expected scalar: got {}'.format(type(value)))
        self._constant = value

    @property
    def L(self):
        return 1.

    def __rmul__(self, other):
        """Defines the right multiplication with a number."""
        if not isinstance(other, Number):
            # NotImplemented is a value to return, not an exception to raise
            return NotImplemented
        constant = self.constant * other
        return ConstantFunction(constant)
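
# Editor's usage sketch, not part of the original source: ConstantFunction has
# a zero gradient and an identity proximal map, as implemented above.
def _demo_constant_function():
    from cil.framework import ImageGeometry

    ig = ImageGeometry(voxel_num_x=3, voxel_num_y=3)
    x = ig.allocate('random')

    c = ConstantFunction(5.0)
    assert c(x) == 5.0
    assert c.gradient(x).norm() == 0                # F'(x) = 0 everywhere
    assert (c.proximal(x, 10.0) - x).norm() == 0    # prox is the identity
    assert (3 * c).constant == 15.0                 # scaling folds into the constant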

class ZeroFunction(ConstantFunction):
    """ZeroFunction represents the zero function, :math:`F(x) = 0`"""

    def __init__(self):
        super(ZeroFunction, self).__init__(constant=0.)

class TranslateFunction(Function):
    r"""TranslateFunction represents the translation of function F with respect to the center b.

    Let F be a function and consider :math:`G(x) = F(x - center)`. Function F is centered at 0, whereas G is centered at point b. If :math:`G(x) = F(x - b)` then:

    1. :math:`G(x) = F(x - b)` ( __call__ method )
    2. :math:`G'(x) = F'(x - b)` ( gradient method )
    3. :math:`G^{*}(x^{*}) = F^{*}(x^{*}) + \langle x^{*}, b \rangle` ( convex_conjugate method )
    4. :math:`\text{prox}_{\tau G}(x) = \text{prox}_{\tau F}(x - b) + b` ( proximal method )
    """

    def __init__(self, function, center):
        try:
            L = function.L
        except NotImplementedError:
            L = None
        super(TranslateFunction, self).__init__(L=L)
        self.function = function
        self.center = center

    def __call__(self, x):
        r"""Returns the value of the translated function.

        .. math:: G(x) = F(x - b)

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The value of the translated function evaluated at :math:`x`.
        """
        try:
            # subtract the center in place when supported, and undo it below
            x.subtract(self.center, out=x)
            tmp = x
        except TypeError:
            tmp = x.subtract(self.center, dtype=np.float32)
        val = self.function(tmp)
        if id(tmp) == id(x):
            x.add(self.center, out=x)
        return val
    def gradient(self, x, out=None):
        r"""Returns the gradient of the translated function.

        .. math:: G'(x) = F'(x - b)

        Parameters
        ----------
        x : DataContainer
            Point to evaluate the gradient at.
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the gradient of the translated function evaluated at :math:`x`.
        """
        if id(x) == id(out):
            raise InPlaceError
        try:
            x.subtract(self.center, out=x)
            tmp = x
        except TypeError:
            tmp = x.subtract(self.center, dtype=np.float32)
        val = self.function.gradient(tmp, out=out)
        if id(tmp) == id(x):
            x.add(self.center, out=x)
        return val
    def proximal(self, x, tau, out=None):
        r"""Returns the proximal operator of the translated function.

        .. math:: \text{prox}_{\tau G}(x) = \text{prox}_{\tau F}(x-b) + b

        Parameters
        ----------
        x : DataContainer
        tau : scalar
        out : DataContainer, optional
            If provided, the result is written to `out`; otherwise a new DataContainer is returned. Default None.

        Returns
        -------
        DataContainer, the proximal operator of the translated function at :math:`x` and :math:`\tau`.
        """
        if id(x) == id(out):
            raise InPlaceError
        try:
            x.subtract(self.center, out=x)
            tmp = x
        except TypeError:
            tmp = x.subtract(self.center, dtype=np.float32)
        val = self.function.proximal(tmp, tau, out=out)
        val.add(self.center, out=val)
        if id(tmp) == id(x):
            x.add(self.center, out=x)
        return val
    def convex_conjugate(self, x):
        r"""Returns the convex conjugate of the translated function.

        .. math:: G^{*}(x^{*}) = F^{*}(x^{*}) + \langle x^{*}, b \rangle

        Parameters
        ----------
        x : DataContainer

        Returns
        -------
        The value of the convex conjugate of the translated function at :math:`x`.
        """
        return self.function.convex_conjugate(x) + self.center.dot(x)
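
# Editor's usage sketch, not part of the original source: checking the
# TranslateFunction identities above against L2NormSquared's own `b` keyword
# (F(x) = ||x - b||^2), which builds the same function directly.
def _demo_translate_function():
    from cil.framework import ImageGeometry
    from cil.optimisation.functions import L2NormSquared

    ig = ImageGeometry(voxel_num_x=3, voxel_num_y=3)
    x = ig.allocate('random')
    b = ig.allocate(1)

    f = L2NormSquared()
    g = f.centered_at(b)            # TranslateFunction, G(x) = F(x - b)
    h = L2NormSquared(b=b)          # same function, built directly

    assert abs(g(x) - h(x)) < 1e-5
    assert (g.gradient(x) - h.gradient(x)).norm() < 1e-5
    assert (g.proximal(x, 0.5) - h.proximal(x, 0.5)).norm() < 1e-5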