Source code for cil.optimisation.functions.OperatorCompositionFunction

#  Copyright 2019 United Kingdom Research and Innovation
#  Copyright 2019 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

from cil.optimisation.functions import Function
from cil.optimisation.operators import Operator, ScaledOperator

import warnings

[docs] class OperatorCompositionFunction(Function): """ Composition of a function with an operator as : :math:`(F \circ A)(x) = F(Ax)` :parameter function: :code:`Function` F :parameter operator: :code:`Operator` A For general operator, we have no explicit formulas for convex_conjugate, proximal and proximal_conjugate """ def __init__(self, function, operator): '''creator :param A: operator :type A: :code:`Operator` :param f: function :type f: :code:`Function` ''' super(OperatorCompositionFunction, self).__init__() self.function = function self.operator = operator @property def L(self): if self._L is None: try: self._L = self.function.L * (self.operator.norm() ** 2) except ValueError as ve: self._L = None return self._L def __call__(self, x): """ Returns :math:`F(Ax)` """ return self.function(self.operator.direct(x))
[docs] def gradient(self, x, out=None): """ Return the gradient of :math:`F(Ax)`, :math:`(F(Ax))' = A^{T}F'(Ax)` """ tmp = self.operator.range_geometry().allocate() self.operator.direct(x, out=tmp) self.function.gradient(tmp, out=tmp) return self.operator.adjoint(tmp, out=out)