Source code for cil.optimisation.functions.Function

# -*- coding: utf-8 -*-
#  Copyright 2019 United Kingdom Research and Innovation
#  Copyright 2019 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

import warnings

from numbers import Number
import numpy as np
from functools import reduce

[docs]class Function(object): """ Abstract class representing a function :param L: Lipschitz constant of the gradient of the function F(x), when it is differentiable. :type L: number, positive, default None :param domain: The domain of the function. Lipschitz of the gradient of the function; it is a positive real number, such that |f'(x) - f'(y)| <= L ||x-y||, assuming f: IG --> R """
[docs] def __init__(self, L = None): # overrides the type check to allow None as initial value self._L = L
[docs] def __call__(self,x): r"""Returns the value of the function F at x: :math:`F(x)` """ raise NotImplementedError
[docs] def gradient(self, x, out=None): r"""Returns the value of the gradient of function F at x, if it is differentiable .. math:: F'(x) """ raise NotImplementedError
[docs] def proximal(self, x, tau, out=None): r"""Returns the proximal operator of function :math:`\tau F` at x .. math:: \mathrm{prox}_{\tau F}(x) = \underset{z}{\mathrm{argmin}} \frac{1}{2}\|z - x\|^{2} + \tau F(z) """ raise NotImplementedError
[docs] def convex_conjugate(self, x): r""" Returns the convex conjugate of function :math:`F` at :math:`x^{*}`, .. math:: F^{*}(x^{*}) = \underset{x^{*}}{\sup} <x^{*}, x> - F(x) """ raise NotImplementedError
[docs] def proximal_conjugate(self, x, tau, out = None): r"""Returns the proximal operator of the convex conjugate of function :math:`\tau F` at :math:`x^{*}` .. math:: \mathrm{prox}_{\tau F^{*}}(x^{*}) = \underset{z^{*}}{\mathrm{argmin}} \frac{1}{2}\|z^{*} - x^{*}\|^{2} + \tau F^{*}(z^{*}) Due to Moreau’s identity, we have an analytic formula to compute the proximal operator of the convex conjugate :math:`F^{*}` .. math:: \mathrm{prox}_{\tau F^{*}}(x) = x - \tau\mathrm{prox}_{\tau^{-1} F}(\tau^{-1}x) """ try: tmp = x x.divide(tau, out = tmp) except TypeError: tmp = x.divide(tau, dtype=np.float32) if out is None: val = self.proximal(tmp, 1.0/tau) else: self.proximal(tmp, 1.0/tau, out = out) val = out if id(tmp) == id(x): x.multiply(tau, out = x) # CIL issue #1078, cannot use axpby # val.axpby(-tau, 1.0, x, out=val) val.multiply(-tau, out = val) val.add(x, out = val) if out is None: return val
# Algebra for Function Class # Add functions # Subtract functions # Add/Substract with Scalar # Multiply with Scalar
[docs] def __add__(self, other): """ Returns the sum of the functions. Cases: a) the sum of two functions :math:`(F_{1}+F_{2})(x) = F_{1}(x) + F_{2}(x)` b) the sum of a function with a scalar :math:`(F_{1}+scalar)(x) = F_{1}(x) + scalar` """ if isinstance(other, Function): return SumFunction(self, other) elif isinstance(other, (SumScalarFunction, ConstantFunction, Number)): return SumScalarFunction(self, other) else: raise ValueError('Not implemented')
[docs] def __radd__(self, other): """ Making addition commutative. """ return self + other
[docs] def __sub__(self, other): """ Returns the subtraction of the functions.""" return self + (-1) * other
[docs] def __rmul__(self, scalar): """Returns a function multiplied by a scalar.""" return ScaledFunction(self, scalar)
def __mul__(self, scalar): return self.__rmul__(scalar)
[docs] def centered_at(self, center): """ Returns a translated function, namely if we have a function :math:`F(x)` the center is at the origin. TranslateFunction is :math:`F(x - b)` and the center is at point b.""" if center is None: return self else: return TranslateFunction(self, center)
@property def L(self): '''Lipschitz of the gradient of function f. L is positive real number, such that |f'(x) - f'(y)| <= L ||x-y||, assuming f: IG --> R''' return self._L # return self._L @L.setter def L(self, value): '''Setter for Lipschitz constant''' if isinstance(value, (Number,)) and value >= 0: self._L = value else: raise TypeError('The Lipschitz constant is a real positive number')
[docs]class SumFunction(Function): r"""SumFunction represents the sum of :math:`n\geq2` functions .. math:: (F_{1} + F_{2} + ... + F_{n})(\cdot) = F_{1}(\cdot) + F_{2}(\cdot) + ... + F_{n}(\cdot) Parameters ---------- *functions : Functions Functions to set up a :class:`.SumFunction` Raises ------ ValueError If the number of function is strictly less than 2. Examples -------- .. math:: F(x) = \|x\|^{2} + \frac{1}{2}\|x - 1\|^{2} >>> from cil.optimisation.functions import L2NormSquared >>> from cil.framework import ImageGeometry >>> f1 = L2NormSquared() >>> f2 = 0.5 * L2NormSquared(b = ig.allocate(1)) >>> F = SumFunction(f1, f2) .. math:: F(x) = \sum_{i=1}^{50} \|x - i\|^{2} >>> F = SumFunction(*[L2NormSquared(b=i) for i in range(50)]) """
[docs] def __init__(self, *functions ): super(SumFunction, self).__init__() if len(functions) < 2: raise ValueError('At least 2 functions need to be passed') self.functions = functions
@property def L(self): """Returns the Lipschitz constant for the SumFunction .. math:: L = \sum_{i} L_{i} where :math:`L_{i}` is the Lipschitz constant of the smooth function :math:`F_{i}`. """ L = 0. for f in self.functions: if f.L is not None: L += f.L else: L = None break self._L = L return self._L @L.setter def L(self, value): # call base class setter super(SumFunction, self.__class__).L.fset(self, value ) @property def Lmax(self): """Returns the maximum Lipschitz constant for the SumFunction .. math:: L = \max_{i}\{L_{i}\} where :math:`L_{i}` is the Lipschitz constant of the smooth function :math:`F_{i}`. """ l = [] for f in self.functions: if f.L is not None: l.append(f.L) else: l = None break self._Lmax = max(l) return self._Lmax @Lmax.setter def Lmax(self, value): # call base class setter super(SumFunction, self.__class__).Lmax.fset(self, value )
[docs] def __call__(self,x): r"""Returns the value of the sum of functions at :math:`x`. .. math:: (F_{1} + F_{2} + ... + F_{n})(x) = F_{1}(x) + F_{2}(x) + ... + F_{n}(x) """ ret = 0. for f in self.functions: ret += f(x) return ret
[docs] def gradient(self, x, out=None): r"""Returns the value of the sum of the gradient of functions at :math:`x`, if all of them are differentiable. .. math:: (F'_{1} + F'_{2} + ... + F'_{n})(x) = F'_{1}(x) + F'_{2}(x) + ... + F'_{n}(x) """ if out is None: for i,f in enumerate(self.functions): if i == 0: ret = f.gradient(x) else: ret += f.gradient(x) return ret else: for i,f in enumerate(self.functions): if i == 0: f.gradient(x, out=out) else: out += f.gradient(x)
[docs] def __add__(self, other): """ Addition for the SumFunction. * :code:`SumFunction` + :code:`SumFunction` is a :code:`SumFunction`. * :code:`SumFunction` + :code:`Function` is a :code:`SumFunction`. """ if isinstance(other, SumFunction): functions = list(self.functions) + list(other.functions) return SumFunction(*functions) elif isinstance(other, Function): functions = list(self.functions) functions.append(other) return SumFunction(*functions) else: return super(SumFunction, self).__add__(other)
[docs]class ScaledFunction(Function): r""" ScaledFunction represents the scalar multiplication with a Function. Let a function F then and a scalar :math:`\alpha`. If :math:`G(x) = \alpha F(x)` then: 1. :math:`G(x) = \alpha F(x)` ( __call__ method ) 2. :math:`G'(x) = \alpha F'(x)` ( gradient method ) 3. :math:`G^{*}(x^{*}) = \alpha F^{*}(\frac{x^{*}}{\alpha})` ( convex_conjugate method ) 4. :math:`\mathrm{prox}_{\tau G}(x) = \mathrm{prox}_{(\tau\alpha) F}(x)` ( proximal method ) """
[docs] def __init__(self, function, scalar): super(ScaledFunction, self).__init__() if not isinstance (scalar, Number): raise TypeError('expected scalar: got {}'.format(type(scalar))) self.scalar = scalar self.function = function
@property def L(self): if self._L is None: if self.function.L is not None: self._L = abs(self.scalar) * self.function.L else: self._L = None return self._L @L.setter def L(self, value): # call base class setter super(ScaledFunction, self.__class__).L.fset(self, value ) @property def scalar(self): return self._scalar @scalar.setter def scalar(self, value): if isinstance(value, (Number, )): self._scalar = value else: raise TypeError('Expecting scalar type as a number type. Got {}'.format(type(value)))
[docs] def __call__(self,x, out=None): r"""Returns the value of the scaled function. .. math:: G(x) = \alpha F(x) """ return self.scalar * self.function(x)
[docs] def convex_conjugate(self, x): r"""Returns the convex conjugate of the scaled function. .. math:: G^{*}(x^{*}) = \alpha F^{*}(\frac{x^{*}}{\alpha}) """ try: x.divide(self.scalar, out = x) tmp = x except TypeError: tmp = x.divide(self.scalar, dtype=np.float32) val = self.function.convex_conjugate(tmp) if id(tmp) == id(x): x.multiply(self.scalar, out = x) return self.scalar * val
[docs] def gradient(self, x, out=None): r"""Returns the gradient of the scaled function. .. math:: G'(x) = \alpha F'(x) """ if out is None: return self.scalar * self.function.gradient(x) else: self.function.gradient(x, out=out) out *= self.scalar
[docs] def proximal(self, x, tau, out=None): r"""Returns the proximal operator of the scaled function. .. math:: \mathrm{prox}_{\tau G}(x) = \mathrm{prox}_{(\tau\alpha) F}(x) """ return self.function.proximal(x, tau*self.scalar, out=out)
[docs] def proximal_conjugate(self, x, tau, out = None): r"""This returns the proximal operator for the function at x, tau """ try: tmp = x x.divide(tau, out = tmp) except TypeError: tmp = x.divide(tau, dtype=np.float32) if out is None: val = self.function.proximal(tmp, self.scalar/tau ) else: self.function.proximal(tmp, self.scalar/tau, out = out) val = out if id(tmp) == id(x): x.multiply(tau, out = x) # CIL issue #1078, cannot use axpby #val.axpby(-tau, 1.0, x, out=val) val.multiply(-tau, out = val) val.add(x, out = val) if out is None: return val
[docs]class SumScalarFunction(SumFunction): """ SumScalarFunction represents the sum a function with a scalar. .. math:: (F + scalar)(x) = F(x) + scalar Although SumFunction has no general expressions for i) convex_conjugate ii) proximal iii) proximal_conjugate if the second argument is a ConstantFunction then we can derive the above analytically. """
[docs] def __init__(self, function, constant): super(SumScalarFunction, self).__init__(function, ConstantFunction(constant)) self.constant = constant self.function = function
[docs] def convex_conjugate(self,x): r""" Returns the convex conjugate of a :math:`(F+scalar)` .. math:: (F+scalar)^{*}(x^{*}) = F^{*}(x^{*}) - scalar """ return self.function.convex_conjugate(x) - self.constant
[docs] def proximal(self, x, tau, out=None): """ Returns the proximal operator of :math:`F+scalar` .. math:: \mathrm{prox}_{\tau (F+scalar)}(x) = \mathrm{prox}_{\tau F} """ return self.function.proximal(x, tau, out=out)
@property def L(self): if self._L is None: if self.function.L is not None: self._L = self.function.L else: self._L = None return self._L @L.setter def L(self, value): # call base class setter super(SumScalarFunction, self.__class__).L.fset(self, value )
[docs]class ConstantFunction(Function): r""" ConstantFunction: :math:`F(x) = constant, constant\in\mathbb{R}` """
[docs] def __init__(self, constant = 0): self.constant = constant super(ConstantFunction, self).__init__(L=0)
[docs] def __call__(self,x): """ Returns the value of the function, :math:`F(x) = constant`""" return self.constant
[docs] def gradient(self, x, out=None): """ Returns the value of the gradient of the function, :math:`F'(x)=0`""" if out is None: return x * 0. else: out.fill(0)
[docs] def convex_conjugate(self, x): r""" The convex conjugate of constant function :math:`F(x) = c\in\mathbb{R}` is .. math:: F(x^{*}) = \begin{cases} -c, & if x^{*} = 0\\ \infty, & \mbox{otherwise} \end{cases} However, :math:`x^{*} = 0` only in the limit of iterations, so in fact this can be infinity. We do not want to have inf values in the convex conjugate, so we have to penalise this value accordingly. The following penalisation is useful in the PDHG algorithm, when we compute primal & dual objectives for convergence purposes. .. math:: F^{*}(x^{*}) = \sum \max\{x^{*}, 0\} """ return x.maximum(0).sum()
[docs] def proximal(self, x, tau, out=None): """Returns the proximal operator of the constant function, which is the same element, i.e., .. math:: \mathrm{prox}_{\tau F}(x) = x """ if out is None: return x.copy() else: out.fill(x)
@property def constant(self): return self._constant @constant.setter def constant(self, value): if not isinstance (value, Number): raise TypeError('expected scalar: got {}'.format(type(value))) self._constant = value @property def L(self): return 0.
[docs] def __rmul__(self, other): '''defines the right multiplication with a number''' if not isinstance (other, Number): raise NotImplemented constant = self.constant * other return ConstantFunction(constant)
[docs]class ZeroFunction(ConstantFunction): """ ZeroFunction represents the zero function, :math:`F(x) = 0` """
[docs] def __init__(self): super(ZeroFunction, self).__init__(constant = 0.)
[docs]class TranslateFunction(Function): r""" TranslateFunction represents the translation of function F with respect to the center b. Let a function F and consider :math:`G(x) = F(x - center)`. Function F is centered at 0, whereas G is centered at point b. If :math:`G(x) = F(x - b)` then: 1. :math:`G(x) = F(x - b)` ( __call__ method ) 2. :math:`G'(x) = F'(x - b)` ( gradient method ) 3. :math:`G^{*}(x^{*}) = F^{*}(x^{*}) + <x^{*}, b >` ( convex_conjugate method ) 4. :math:`\mathrm{prox}_{\tau G}(x) = \mathrm{prox}_{\tau F}(x - b) + b` ( proximal method ) """
[docs] def __init__(self, function, center): try: L = function.L except NotImplementedError as nie: L = None super(TranslateFunction, self).__init__(L = L) self.function = function self.center = center
[docs] def __call__(self, x): r"""Returns the value of the translated function. .. math:: G(x) = F(x - b) """ try: x.subtract(self.center, out = x) tmp = x except TypeError: tmp = x.subtract(self.center, dtype=np.float32) val = self.function(tmp) if id(tmp) == id(x): x.add(self.center, out = x) return val
[docs] def gradient(self, x, out = None): r"""Returns the gradient of the translated function. .. math:: G'(x) = F'(x - b) """ try: x.subtract(self.center, out = x) tmp = x except TypeError: tmp = x.subtract(self.center, dtype=np.float32) if out is None: val = self.function.gradient(tmp) else: self.function.gradient(tmp, out = out) if id(tmp) == id(x): x.add(self.center, out = x) if out is None: return val
[docs] def proximal(self, x, tau, out = None): r"""Returns the proximal operator of the translated function. .. math:: \mathrm{prox}_{\tau G}(x) = \mathrm{prox}_{\tau F}(x-b) + b """ try: x.subtract(self.center, out = x) tmp = x except TypeError: tmp = x.subtract(self.center, dtype=np.float32) if out is None: val = self.function.proximal(tmp, tau) val.add(self.center, out = val) else: self.function.proximal(tmp, tau, out = out) out.add(self.center, out = out) if id(tmp) == id(x): x.add(self.center, out = x) if out is None: return val
[docs] def convex_conjugate(self, x): r"""Returns the convex conjugate of the translated function. .. math:: G^{*}(x^{*}) = F^{*}(x^{*}) + <x^{*}, b > """ return self.function.convex_conjugate(x) + self.center.dot(x)
if __name__ == "__main__": F1 = Function() F2 = Function() res1 = F1 + F2 print("sum two function", res1.__class__) # SumFunction res2 = F1 + 5 print("sum function and scalar",res2.__class__) # SumScalarFunction res3 = 5 + F1 print("sum scalar and function",res3.__class__) # SumScalarFunction res4 = F1 + ConstantFunction(5) print("sum function and constant",res4.__class__) # SumFunction res4 = ConstantFunction(5) + 5 print("sum constant and function",res4.__class__) # SumScalarFunction res4 = ZeroFunction() + 5 print("sum zero function and function",res4.__class__) # SumScalarFunction res3 = res1 + (F1+F2) print(res3.__class__) from cil.optimisation.functions import L2NormSquared from cil.framework import ImageGeometry ig = ImageGeometry(3,4) f1 = L2NormSquared() f2 = 0.5 * L2NormSquared(b = 1) F = SumFunction(f1, f2) x = ig.allocate(5) F(x) G = SumFunction(*[f1]*4) len(G.functions) F = SumFunction(*[L2NormSquared(b=ig.allocate(i)) for i in range(10)]) print(len(F.functions)) # L = F.L # print(res, L)