#  Copyright 2020 United Kingdom Research and Innovation
#  Copyright 2020 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
try:
    from ccpi.filters import regularisers
    from ccpi.filters.TV import TV_ENERGY
except ImportError as exc:
    raise ImportError('Please `conda install "ccpi::ccpi-regulariser>=24"`') from exc
from cil.framework import DataOrder
from cil.framework import DataContainer
from cil.optimisation.functions import Function
import numpy as np
import warnings
from numbers import Number
class RegulariserFunction(Function):
    def proximal(self, x, tau, out=None):
        r""" Generic proximal method for a RegulariserFunction
        .. math:: \mathrm{prox}_{\tau f}(x) := \argmin_{z} f(x) + \frac{1}{2}\|z - x \|^{2}
        Parameters
        ----------
        x : DataContainer
            Input of the proximal operator
        tau : Number
            Positive parameter of the proximal operator
        out : DataContainer
            Output :class:`Datacontainer` in which the result is placed.
        Note
        ----
        If the :class:`ImageData` contains complex data, rather than the default `float32`, the regularisation
        is run independently on the real and imaginary part.
        """
        self.check_input(x)
        arr = x.as_array()
        if np.iscomplexobj(arr):
            # do real and imag part indep
            in_arr = np.asarray(arr.real, dtype=np.float32, order='C')
            res, info = self.proximal_numpy(in_arr, tau)
            arr.real = res[:]
            in_arr = np.asarray(arr.imag, dtype=np.float32, order='C')
            res, info = self.proximal_numpy(in_arr, tau)
            arr.imag = res[:]
            self.info = info
            if out is not None:
                out.fill(arr)
            else:
                out = x.copy()
                out.fill(arr)
                return out
        else:
            arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
            res, info = self.proximal_numpy(arr, tau)
            self.info = info
            if out is not None:
                out.fill(res)
            else:
                out = x.copy()
                out.fill(res)
                return out
    def proximal_numpy(self, xarr, tau):
        raise NotImplementedError('Please implement proximal_numpy')
    def check_input(self, input):
        pass
class TV_Base(RegulariserFunction):
    r""" Total Variation regulariser
    .. math:: TV(u) = \alpha \|\nabla u\|_{2,1}
    Parameters
    ----------
    strong_convexity_constant : Number
                              Positive parameter that allows Total variation regulariser to be strongly convex. Default = 0.
    Note
    ----
    By definition, Total variation is a convex function. However,
    adding a strongly convex term makes it a strongly convex function.
    Then, we say that `TV` is a :math:`\gamma>0` strongly convex function i.e.,
    .. math:: TV(u) = \alpha \|\nabla u\|_{2,1} + \frac{\gamma}{2}\|u\|^{2}
    """
    def __init__(self, strong_convexity_constant = 0):
        self.strong_convexity_constant = strong_convexity_constant
    def __call__(self,x):
        in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
        EnergyValTV = TV_ENERGY(in_arr, in_arr, self.alpha, 2)
        if self.strong_convexity_constant>0:
            return 0.5*EnergyValTV[0] + (self.strong_convexity_constant/2)*x.squared_norm()
        else:
            return 0.5*EnergyValTV[0]
    def convex_conjugate(self,x):
        return 0.0
[docs]class FGP_TV(TV_Base):
    r""" Fast Gradient Projection Total Variation (FGP_TV)
        The :class:`FGP_TV` computes the proximal operator of the Total variation regulariser
        .. math:: \mathrm{prox}_{\tau (\alpha TV)}(x) = \underset{z}{\mathrm{argmin}} \,\alpha\,\mathrm{TV}(z) + \frac{1}{2}\|z - x\|^{2} .
        The algorithm used for the proximal operator of TV is the Fast Gradient Projection algorithm
        applied to the _dual problem_ of the above problem, see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`.
        Parameters
        ----------
        alpha : :obj:`Number` (positive), default = 1.0 .
                Total variation regularisation parameter.
        max_iteration : :obj:`int`. Default = 100 .
                Maximum number of iterations for the Fast Gradient Projection algorithm.
        isotropic : :obj:`boolean`. Default = True .
                    Isotropic or Anisotropic definition of the Total variation regulariser.
                    .. math:: |x|_{2} = \sqrt{x_{1}^{2} + x_{2}^{2}},\, (\mbox{isotropic})
                    .. math:: |x|_{1} = |x_{1}| + |x_{2}|\, (\mbox{anisotropic})
        nonnegativity : :obj:`boolean`. Default = True .
                        Non-negativity constraint for the solution of the FGP algorithm.
        tolerance : :obj:`float`, Default = 0 .
                    Stopping criterion for the FGP algorithm.
                    .. math:: \|x^{k+1} - x^{k}\|_{2} < \mathrm{tolerance}
        device : :obj:`str`, Default = 'cpu' .
                FGP_TV algorithm runs on `cpu` or `gpu`.
        strong_convexity_constant : :obj:`float`, default = 0
                A strongly convex term weighted by the :code:`strong_convexity_constant` (:math:`\gamma`) parameter is added to the Total variation.
                Now the :code:`TotalVariation` function is :math:`\gamma` - strongly convex and the proximal operator is
                .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\tau}\|u - b\|^{2} + \mathrm{TV}(u) + \frac{\gamma}{2}\|u\|^{2} \Leftrightarrow
                .. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\frac{\tau}{1+\gamma\tau}}\|u - \frac{b}{1+\gamma\tau}\|^{2} + \mathrm{TV}(u)
        Examples
        --------
        .. math:: \underset{u\qeq0}{\mathrm{argmin}} \frac{1}{2}\|u - b\|^{2} + \alpha TV(u)
        >>> G = alpha * FGP_TV(max_iteration=100, device='gpu')
        >>> sol = G.proximal(b)
        Note
        ----
        The :class:`FGP_TV` regularisation does not incorparate information on the :class:`ImageGeometry`, i.e., pixel/voxel size.
        Therefore a rescaled parameter should be used to match the same solution computed using :class:`~cil.optimisation.functions.TotalVariation`.
        >>> G1 = (alpha/ig.voxel_size_x) * FGP_TV(max_iteration=100, device='gpu')
        >>> G2 = alpha * TotalVariation(max_iteration=100, lower=0.)
        See Also
        --------
        :class:`~cil.optimisation.functions.TotalVariation`
        """
    def __init__(self, alpha=1, max_iteration=100, tolerance=0, isotropic=True, nonnegativity=True, device='cpu', strong_convexity_constant=0):
        if isotropic == True:
            self.methodTV = 0
        else:
            self.methodTV = 1
        if nonnegativity == True:
            self.nonnegativity = 1
        else:
            self.nonnegativity = 0
        self.alpha = alpha
        self.max_iteration = max_iteration
        self.tolerance = tolerance
        self.nonnegativity = nonnegativity
        self.device = device
        super(FGP_TV, self).__init__(strong_convexity_constant=strong_convexity_constant)
    def _fista_on_dual_rof(self, in_arr, tau):
        r""" Implements the Fast Gradient Projection algorithm on the dual problem
        of the Total Variation Denoising problem (ROF).
        """
        info = np.zeros((2,), dtype=np.float32)
        res = regularisers.FGP_TV(\
              
in_arr,\
              
self.alpha * tau,\
              
self.max_iteration,\
              
self.tolerance,\
              
self.methodTV,\
              
self.nonnegativity,\
              
infovector = info,
              device = self.device)
        return res, info
    def proximal_numpy(self, in_arr, tau):
        if self.strong_convexity_constant>0:
            strongly_convex_factor = (1 + tau * self.strong_convexity_constant)
            in_arr /= strongly_convex_factor
            tau /= strongly_convex_factor
        solution = self._fista_on_dual_rof(in_arr, tau)
        if self.strong_convexity_constant>0:
            in_arr *= strongly_convex_factor
            tau *= strongly_convex_factor
        return solution
    def __rmul__(self, scalar):
        '''Define the multiplication with a scalar
        this changes the regularisation parameter in the plugin'''
        if not isinstance (scalar, Number):
            raise NotImplemented
        else:
            self.alpha *= scalar
            return self
    def check_input(self, input):
        if len(input.shape) > 3:
            raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) 
[docs]class TGV(RegulariserFunction):
[docs]    def __init__(self, alpha=1, gamma=1, max_iteration=100, tolerance=0, device='cpu' , **kwargs):
        '''Creator of Total Generalised Variation Function
        :param alpha: regularisation parameter
        :type alpha: number, default 1
        :param gamma: ratio of TGV terms
        :type gamma: number, default 1, can range between 1 and 2
        :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached
        :type max_iteration: integer, default 100
        :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion.
        :type tolerance: float, default 0
        :param device: determines if the code runs on CPU or GPU
        :type device: string, default 'cpu', can be 'gpu' if GPU is installed
        '''
        self.alpha = alpha
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.tolerance = tolerance
        self.device = device
        if kwargs.get('iter_TGV', None) is not None:
            # raise ValueError('iter_TGV parameter has been superseded by num_iter. Use that instead.')
            self.num_iter = kwargs.get('iter_TGV') 
[docs]    def __call__(self,x):
        warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan 
    @property
    def gamma(self):
        return self.__gamma
    @gamma.setter
    def gamma(self, value):
        if value <= 2 and value >= 1:
            self.__gamma = value
    @property
    def alpha2(self):
        return self.alpha1 * self.gamma
    @property
    def alpha1(self):
        return 1.
    def proximal_numpy(self, in_arr, tau):
        info = np.zeros((2,), dtype=np.float32)
        res = regularisers.TGV(in_arr,
              self.alpha * tau,
              self.alpha1,
              self.alpha2,
              self.max_iteration,
              self.LipshitzConstant,
              self.tolerance,
              infovector = info,
              device = self.device)
        # info: return number of iteration and reached tolerance
        # https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/master/src/Core/regularisers_CPU/TGV_core.c#L168
        # Stopping Criteria  || u^k - u^(k-1) ||_{2} / || u^{k} ||_{2}
        return res, info
[docs]    def convex_conjugate(self, x):
        warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan 
[docs]    def __rmul__(self, scalar):
        '''Define the multiplication with a scalar
        this changes the regularisation parameter in the plugin'''
        if not isinstance (scalar, Number):
            raise NotImplemented
        else:
            self.alpha *= scalar
            return self 
        # f = TGV()
        # f = alpha * f
    def check_input(self, input):
        if len(input.shape) == 2:
            self.LipshitzConstant = 12
        elif len(input.shape) == 3:
            self.LipshitzConstant = 16 # Vaggelis to confirm
        else:
            raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) 
[docs]class FGP_dTV(RegulariserFunction):
    '''Creator of FGP_dTV Function
        :param reference: reference image
        :type reference: ImageData
        :param alpha: regularisation parameter
        :type alpha: number, default 1
        :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached
        :type max_iteration: integer, default 100
        :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion.
        :type tolerance: float, default 0
        :param eta: smoothing constant to calculate gradient of the reference
        :type eta: number, default 0.01
        :param isotropic: Whether it uses L2 (isotropic) or L1 (anisotropic) norm
        :type isotropic: boolean, default True, can range between 1 and 2
        :param nonnegativity: Whether to add the non-negativity constraint
        :type nonnegativity: boolean, default True
        :param device: determines if the code runs on CPU or GPU
        :type device: string, default 'cpu', can be 'gpu' if GPU is installed
        '''
[docs]    def __init__(self, reference, alpha=1, max_iteration=100,
                 tolerance=0, eta=0.01, isotropic=True, nonnegativity=True, device='cpu'):
        if isotropic == True:
            self.methodTV = 0
        else:
            self.methodTV = 1
        if nonnegativity == True:
            self.nonnegativity = 1
        else:
            self.nonnegativity = 0
        self.alpha = alpha
        self.max_iteration = max_iteration
        self.tolerance = tolerance
        self.device = device # string for 'cpu' or 'gpu'
        self.reference = np.asarray(reference.as_array(), dtype=np.float32)
        self.eta = eta 
[docs]    def __call__(self,x):
        warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan 
    def proximal_numpy(self, in_arr, tau):
        
        info = np.zeros((2,), dtype=np.float32)
        res = regularisers.FGP_dTV(\
                
in_arr,\
                
self.reference,\
                
self.alpha * tau,\
                
self.max_iteration,\
                
self.tolerance,\
                
self.eta,\
                
self.methodTV,\
                
self.nonnegativity,\
                
infovector = info,
                device = self.device)
        return res, info
[docs]    def convex_conjugate(self, x):
        warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan 
[docs]    def __rmul__(self, scalar):
        '''Define the multiplication with a scalar
        this changes the regularisation parameter in the plugin'''
        if not isinstance (scalar, Number):
            raise NotImplemented
        else:
            self.alpha *= scalar
            return self 
    def check_input(self, input):
        if len(input.shape) > 3:
            raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) 
[docs]class TNV(RegulariserFunction):
[docs]    def __init__(self,alpha=1, max_iteration=100, tolerance=0):
        '''Creator of TNV Function
        :param alpha: regularisation parameter
        :type alpha: number, default 1
        :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached
        :type max_iteration: integer, default 100
        :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion.
        :type tolerance: float, default 0
        '''
        # set parameters
        self.alpha = alpha
        self.max_iteration = max_iteration
        self.tolerance = tolerance 
[docs]    def __call__(self,x):
        warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan 
    def proximal_numpy(self, in_arr, tau):
        # remove any dimension of size 1
        in_arr = np.squeeze(in_arr)
    
        res = regularisers.TNV(in_arr,
              self.alpha * tau,
              self.max_iteration,
              self.tolerance)
        return res, []
[docs]    def convex_conjugate(self, x):
        warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan 
[docs]    def __rmul__(self, scalar):
        '''Define the multiplication with a scalar
        this changes the regularisation parameter in the plugin'''
        if not isinstance (scalar, Number):
            raise NotImplemented
        else:
            self.alpha *= scalar
            return self