#  Copyright 2019 United Kingdom Research and Innovation
#  Copyright 2019 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.functions import IndicatorBox
from cil.framework import BlockDataContainer
from cil.utilities.errors import InPlaceError
import numpy
import logging
# Module-level logger; SIRT.set_up reports its progress through it at INFO level.
log = logging.getLogger(__name__)
class SIRT(Algorithm):
    r"""Simultaneous Iterative Reconstruction Technique, see :cite:`Kak2001`.

    Simultaneous Iterative Reconstruction Technique (SIRT) solves
    the following problem

    .. math:: A x = b

    The SIRT algorithm is

    .. math:: x^{k+1} =  \mathrm{proj}_{C}( x^{k} + \omega * D ( A^{T} ( M * (b - Ax^{k}) ) ) ),

    where,
    :math:`M = \frac{1}{A*\mathbb{1}}`,
    :math:`D = \frac{1}{A^{T}\mathbb{1}}`,
    :math:`\mathbb{1}` is a :code:`DataContainer` of ones,
    :math:`\mathrm{proj}_{C}` is the projection over a set :math:`C`,
    and :math:`\omega` is the relaxation parameter.

    Parameters
    ----------
    initial : DataContainer or None
        Starting point of the algorithm. If :code:`None` is passed, a zero-filled
        container allocated from :code:`operator.domain_geometry()` is used.
    operator : LinearOperator
        The operator A.
    data : DataContainer
        The data b.
    lower : :obj:`float`, default = None
        Lower bound constraint
    upper : :obj:`float`, default = None
        Upper bound constraint
    constraint : Function, default = None
        A function with :code:`proximal` method, e.g., :class:`.IndicatorBox` function and :meth:`.IndicatorBox.proximal`,
        or :class:`.TotalVariation` function and :meth:`.TotalVariation.proximal`.
    kwargs:
        Keyword arguments used from the base class :class:`.Algorithm`.

    Note
    ----
    If :code:`constraint` is not passed, :code:`lower` and :code:`upper` are used to create an :class:`.IndicatorBox` and apply its :code:`proximal`.

    If :code:`constraint` is passed, :code:`proximal` method is required to be implemented.

    Note
    ----
    The preconditioning arrays (weights) :code:`M` and :code:`D` used in SIRT are defined as

    .. math:: M = \frac{1}{A*\mathbb{1}} = \frac{1}{\sum_{j}a_{i,j}}

    .. math:: D = \frac{1}{A^{T}*\mathbb{1}} = \frac{1}{\sum_{i}a_{i,j}}

    Entries of :code:`M` and :code:`D` that evaluate to NaN or inf are replaced with 1.

    Examples
    --------
    .. math:: \underset{x}{\mathrm{argmin}} \frac{1}{2}\| A x - d\|^{2}

    >>> sirt = SIRT(initial = ig.allocate(0), operator = A, data = d, max_iteration = 5)
    """

    def __init__(self, initial, operator, data, lower=None, upper=None, constraint=None, **kwargs):
        # Remaining keyword arguments are forwarded untouched to the Algorithm base class.
        super().__init__(**kwargs)
        self.set_up(initial=initial, operator=operator, data=data,
                    lower=lower, upper=upper, constraint=constraint)

    def set_up(self, initial, operator, data, lower=None, upper=None, constraint=None):
        """Initialisation of the algorithm

        Parameters
        ----------
        initial : DataContainer or None
            Starting point; it is copied, so the caller's container is not modified.
            If :code:`None`, a zero-filled container from :code:`operator.domain_geometry()` is used.
        operator : LinearOperator
            The operator A.
        data : DataContainer
            The data b.
        lower : :obj:`float`, default = None
            Lower bound, only used when :code:`constraint` is None.
        upper : :obj:`float`, default = None
            Upper bound, only used when :code:`constraint` is None.
        constraint : Function, default = None
            Function whose :code:`proximal` is applied after every update.
        """
        log.info("%s setting up", self.__class__.__name__)

        if initial is None:
            # Documented fallback: start from zeros in the operator's domain.
            self.x = operator.domain_geometry().allocate(0)
        else:
            self.x = initial.copy()

        # Work buffers reused by update(): tmp_x lives in the domain, r in the range.
        self.tmp_x = self.x * 0.0
        self.operator = operator
        self.data = data
        self.r = data.copy()

        # An explicit constraint wins; otherwise build a box from the bounds, if any.
        self.constraint = constraint
        if constraint is None and (lower is not None or upper is not None):
            # IndicatorBox accepts None for lower and/or upper
            self.constraint = IndicatorBox(lower=lower, upper=upper)

        self._relaxation_parameter = 1

        # Set up scaling matrices D and M.
        self._set_up_weights()

        self.configured = True
        log.info("%s configured", self.__class__.__name__)

    @property
    def relaxation_parameter(self):
        r"""The relaxation parameter :math:`\omega` currently in use."""
        return self._relaxation_parameter

    @property
    def D(self):
        r"""The weights :math:`D` without the relaxation factor.

        Internally the array is stored already multiplied by :math:`\omega`;
        this property divides that factor back out.
        """
        return self._Dscaled / self._relaxation_parameter

    def set_relaxation_parameter(self, value=1.0):
        r"""Set the relaxation parameter :math:`\omega`

        Parameters
        ----------
        value : float
            The relaxation parameter to be applied to the update. Must be between 0 and 2 to guarantee asymptotic convergence.

        Raises
        ------
        ValueError
            If :code:`value` is not strictly between 0 and 2.
        """
        if value <= 0 or value >= 2:
            raise ValueError("Expected relaxation parameter to be in range 0-2. Got {}".format(value))

        self._relaxation_parameter = value

        # Recompute the plain weights first so that repeated calls do not
        # compound the scaling, then fold omega into the stored D.
        self._set_up_weights()
        self._Dscaled *= self._relaxation_parameter

    def _set_up_weights(self):
        r"""Compute the weights :math:`M = 1/(A\mathbb{1})` and :math:`D = 1/(A^{T}\mathbb{1})`.

        The result for :math:`D` is stored in :code:`_Dscaled` with no relaxation
        factor applied yet. NaN and inf entries of both arrays are replaced by 1.
        """
        ones_in_domain = self.operator.domain_geometry().allocate(value=1.0)
        ones_in_range = self.operator.range_geometry().allocate(value=1.0)

        self.M = 1. / self.operator.direct(ones_in_domain)
        self._Dscaled = 1. / self.operator.adjoint(ones_in_range)

        for weights in (self.M, self._Dscaled):
            self._remove_nan_or_inf(weights, replace_with=1.0)

    def _remove_nan_or_inf(self, datacontainer, replace_with=1.0):
        """Replace nan and inf in datacontainer with a given value.

        Parameters
        ----------
        datacontainer : DataContainer, BlockDataContainer
            Container modified in place.
        replace_with : float, default 1.0
            Value to replace elements that evaluate to NaN or inf

        Note
        ----
        In case the input datacontainer is a :code:`BlockDataContainer` the substitution is executed for each container in the :code:`BlockDataContainer`.
        """
        if isinstance(datacontainer, BlockDataContainer):
            for block in datacontainer.containers:
                self._remove_nan_or_inf(block, replace_with=replace_with)
            return

        values = datacontainer.as_array()
        numpy.nan_to_num(values, copy=False, nan=replace_with, posinf=replace_with, neginf=replace_with)
        # Write the cleaned values back explicitly rather than relying on
        # as_array() having returned a view of the container's storage.
        datacontainer.fill(values)

    def update(self):
        r""" Performs a single iteration of the SIRT algorithm

        .. math:: x^{k+1} =  \mathrm{proj}_{C}( x^{k} + \omega * D ( A^{T} ( M * (b - Ax) ) ) )

        """
        # r = b - A x, computed in the preallocated residual buffer.
        self.operator.direct(self.x, out=self.r)
        self.r.sapyb(-1, self.data, 1.0, out=self.r)

        # r <- M * r, then back-project into tmp_x.
        self.r *= self.M
        self.operator.adjoint(self.r, out=self.tmp_x)

        # x <- x + (omega * D) * tmp_x; _Dscaled already carries the
        # relaxation parameter (default 1).
        self.x.sapyb(1.0, self.tmp_x, self._Dscaled, out=self.x)

        if self.constraint is not None:
            try:
                self.constraint.proximal(self.x, tau=1, out=self.x)
            except InPlaceError:
                # The constraint signalled it cannot write into its own input:
                # fall back to the allocating call.
                self.x = self.constraint.proximal(self.x, tau=1)

    def update_objective(self):
        r"""Appends the current objective value to :code:`self.loss`

        The nominal objective of the algorithm is

        .. math:: \frac{1}{2}\|A x - b\|^{2}

        Note
        ----
        The recorded value is half the squared norm of the residual buffer :code:`self.r`.
        Once :meth:`update` has run, that buffer holds :math:`M * (b - A x)` evaluated at the
        iterate from *before* the most recent step, so the number stored is a weighted,
        one-step-lagged surrogate of the expression above rather than its exact value.
        Before any update the buffer is a copy of the data :math:`b`.
        """
        self.loss.append(0.5 * self.r.squared_norm())