# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# - CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
# - Daniel Deidda (National Physical Laboratory, UK)
# - Claire Delplancke (Electricite de France, Research and Development)
# - Ashley Gillman (Australian e-Health Res. Ctr., CSIRO, Brisbane, Queensland, Australia)
# - Zeljko Kereta (Department of Computer Science, University College London, UK)
# - Evgueni Ovtchinnikov (STFC - UKRI)
# - Georg Schramm (Department of Imaging and Pathology, Division of Nuclear Medicine, KU Leuven, Leuven, Belgium)
from .ApproximateGradientSumFunction import ApproximateGradientSumFunction
from .Function import SumFunction


class SGFunction(ApproximateGradientSumFunction):
r"""
Stochastic gradient function, a child class of :code:`ApproximateGradientSumFunction`, which defines from a list of functions, :math:`{f_0,...,f_{n-1}}` a `SumFunction`, :math:`f_0+...+f_{n-1}` where each time the `gradient` is called, the :code:`sampler` provides an index, :math:`i \in {0,...,n-1}`
and the :code:`gradient` method returns the approximate gradient :math:`n \nabla_x f_i(x)`. This can be used with the :code:`cil.optimisation.algorithms` algorithm :code:`GD` to give a stochastic gradient descent algorithm.
Parameters
-----------
functions : `list` of functions
A list of functions: :code:`[f_{0}, f_{1}, ..., f_{n-1}]`. Each function is assumed to be smooth with an implemented :func:`~Function.gradient` method. Although CIL does not define a domain of a :code:`Function`, all functions are supposed to have the same domain. The number of functions must be strictly greater than 1.
sampler: An instance of a CIL Sampler class ( :meth:`~optimisation.utilities.sampler`) or of another class which has a :code:`__next__` function implemented to output integers in {0,...,n-1}.
This sampler is called each time gradient is called and sets the internal :code:`function_num` passed to the :code:`approximate_gradient` function. Default is :code:`Sampler.random_with_replacement(len(functions))`.
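
    Example
    -------
    A minimal sketch of intended usage (the names :code:`F0`, :code:`F1`, :code:`F2` stand in for smooth CIL functions, e.g. :code:`LeastSquares` objects sharing a domain, and :code:`initial` for a :code:`DataContainer` in that domain; the :code:`GD` keyword names may differ between CIL releases):

    >>> from cil.optimisation.algorithms import GD
    >>> f = SGFunction([F0, F1, F2])  # default sampler: uniformly random with replacement
    >>> gd = GD(initial=initial, f=f, step_size=0.01)
    >>> gd.run(100)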
"""
    def __init__(self, functions, sampler=None):
        super(SGFunction, self).__init__(functions, sampler)
    def approximate_gradient(self, x, function_num, out=None):
        r""" Returns the gradient of the function at index :code:`function_num`, evaluated at :code:`x` and scaled by the number of functions.

        Parameters
        ----------
        x : DataContainer
            The point at which the gradient is evaluated.
        function_num : `int`
            An index in {0, 1, ..., len(functions)-1}.
        out : DataContainer, optional
            If provided, the result is written into :code:`out`; if `None`, a new DataContainer is returned. Default is `None`.

        Returns
        -------
        DataContainer
            The value of the approximate gradient of the sum function at :code:`x`, i.e. :math:`n \nabla_x f_{i}(x)` where :math:`i` = :code:`function_num`.
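
        Example
        -------
        A minimal sketch of a direct call (the names :code:`sg` and :code:`x` stand in for an :code:`SGFunction` and a :code:`DataContainer` from the common domain of its functions):

        >>> g = sg.approximate_gradient(x, 0)  # n * gradient of functions[0] at x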
"""
        if function_num >= self.num_functions or function_num < 0:
            raise IndexError(
                'The sampler has produced an index outside the range of the function list. Please ensure your sampler samples from {0, 1, ..., len(functions)-1} only.')
        # compute the gradient of the function indexed by function_num
        if out is None:
            out = self.functions[function_num].gradient(x)
        else:
            self.functions[function_num].gradient(x, out=out)

        # scale with respect to the number of functions
        out *= self.num_functions
        return out