Source code for cil.plugins.ccpi_regularisation.functions.regularisers
# Copyright 2020 United Kingdom Research and Innovation
# Copyright 2020 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

# The CCPi Regularisation Toolkit provides the C/CUDA implementations wrapped below.
try:
    from ccpi.filters import regularisers
    from ccpi.filters.TV import TV_ENERGY
except ImportError as exc:
    raise ImportError('Please `conda install "ccpi::ccpi-regulariser>=24.0.1"`') from exc

from cil.framework import DataContainer
from cil.framework.labels import ImageDimension
from cil.optimisation.functions import Function
import numpy as np
import warnings
from numbers import Number


class RegulariserFunction(Function):
    # Base class for CCPi-Regularisation-Toolkit wrappers: handles data
    # conversion (float32, C-order) and complex-valued input, and delegates the
    # actual regularisation to the subclass's `proximal_numpy`.

    def proximal(self, x, tau, out=None):
        r""" Generic proximal method for a RegulariserFunction

        .. math:: \mathrm{prox}_{\tau f}(x) := \argmin_{z} f(x) + \frac{1}{2}\|z - x \|^{2}

        Parameters
        ----------
        x : DataContainer
            Input of the proximal operator
        tau : Number
            Positive parameter of the proximal operator
        out : DataContainer
            Output :class:`Datacontainer` in which the result is placed.

        Note
        ----
        If the :class:`ImageData` contains complex data, rather than the
        default `float32`, the regularisation is run independently on the
        real and imaginary part.
        """
        self.check_input(x)
        arr = x.as_array()
        if np.iscomplexobj(arr):
            # do real and imag part indep
            in_arr = np.asarray(arr.real, dtype=np.float32, order='C')
            res, info = self.proximal_numpy(in_arr, tau)
            arr.real = res[:]
            in_arr = np.asarray(arr.imag, dtype=np.float32, order='C')
            res, info = self.proximal_numpy(in_arr, tau)
            arr.imag = res[:]
            # NOTE(review): only the info vector of the imaginary-part pass is
            # kept here; the real-part info is overwritten above — confirm intended.
            self.info = info
            if out is not None:
                out.fill(arr)
            else:
                out = x.copy()
                out.fill(arr)
            return out
        else:
            # Real-valued data: a single pass through the toolkit, which
            # requires contiguous float32 input.
            arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
            res, info = self.proximal_numpy(arr, tau)
            self.info = info
            if out is not None:
                out.fill(res)
            else:
                out = x.copy()
                out.fill(res)
            return out

    def proximal_numpy(self, xarr, tau):
        # Subclasses wrap the concrete CCPi regulariser call here and return
        # a (result_array, info) pair.
        raise NotImplementedError('Please implement proximal_numpy')

    def check_input(self, input):
        # Optional validation hook; subclasses override to reject unsupported data.
        pass


class TV_Base(RegulariserFunction):
    r""" Total Variation regulariser

    .. math:: TV(u) = \alpha \|\nabla u\|_{2,1}

    Parameters
    ----------
    strong_convexity_constant : Number
        Positive parameter that allows Total variation regulariser to be strongly convex. Default = 0.

    Note
    ----
    By definition, Total variation is a convex function. However, adding a
    strongly convex term makes it a strongly convex function. Then, we say
    that `TV` is a :math:`\gamma>0` strongly convex function i.e.,

    .. math:: TV(u) = \alpha \|\nabla u\|_{2,1} + \frac{\gamma}{2}\|u\|^{2}
    """

    def __init__(self, strong_convexity_constant=0):
        self.strong_convexity_constant = strong_convexity_constant

    def __call__(self, x):
        # Evaluate the TV energy via the toolkit. NOTE(review): relies on
        # `self.alpha` being set by the concrete subclass (e.g. FGP_TV) — confirm.
        in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
        EnergyValTV = TV_ENERGY(in_arr, in_arr, self.alpha, 2)
        if self.strong_convexity_constant > 0:
            return 0.5*EnergyValTV[0] + (self.strong_convexity_constant/2)*x.squared_norm()
        else:
            return 0.5*EnergyValTV[0]

    def convex_conjugate(self, x):
        # The convex conjugate is evaluated as 0 here.
        return 0.0
class FGP_TV(TV_Base):
    r""" Fast Gradient Projection Total Variation (FGP_TV)

    The :class:`FGP_TV` computes the proximal operator of the Total variation regulariser

    .. math:: \mathrm{prox}_{\tau (\alpha TV)}(x) = \underset{z}{\mathrm{argmin}} \,\alpha\,\mathrm{TV}(z) + \frac{1}{2}\|z - x\|^{2} .

    The algorithm used for the proximal operator of TV is the Fast Gradient Projection
    algorithm applied to the dual problem of the above problem, see
    :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`.

    Note
    ----
    In CIL Version 24.1.0 we change the default value of nonnegativity to False.
    This means non-negativity is not enforced by default.

    Parameters
    ----------
    alpha : :obj:`Number` (positive), default = 1.0 .
        Total variation regularisation parameter.
    max_iteration : :obj:`int`. Default = 100 .
        Maximum number of iterations for the Fast Gradient Projection algorithm.
    isotropic : :obj:`boolean`. Default = True .
        Isotropic or Anisotropic definition of the Total variation regulariser.

        .. math:: |x|_{2} = \sqrt{x_{1}^{2} + x_{2}^{2}},\, (\mbox{isotropic})

        .. math:: |x|_{1} = |x_{1}| + |x_{2}|\, (\mbox{anisotropic})

    nonnegativity : :obj:`boolean`. Default = False .
        Non-negativity constraint for the solution of the FGP algorithm.
    tolerance : :obj:`float`, Default = 0 .
        Stopping criterion for the FGP algorithm.

        .. math:: \|x^{k+1} - x^{k}\|_{2} < \mathrm{tolerance}

    device : :obj:`str`, Default = 'cpu' .
        FGP_TV algorithm runs on `cpu` or `gpu`.
    strong_convexity_constant : :obj:`float`, default = 0
        A strongly convex term weighted by :math:`\gamma` is added to the Total
        variation, making the function :math:`\gamma`-strongly convex; the
        proximal then reduces to a TV proximal on rescaled data and step size.

    Note
    ----
    The :class:`FGP_TV` regularisation does not incorporate information on the
    :class:`ImageGeometry`, i.e., pixel/voxel size. Therefore a rescaled
    parameter should be used to match the same solution computed using
    :class:`~cil.optimisation.functions.TotalVariation`.

    >>> G1 = (alpha/ig.voxel_size_x) * FGP_TV(max_iteration=100, device='gpu')
    >>> G2 = alpha * TotalVariation(max_iteration=100, lower=0.)

    See Also
    --------
    :class:`~cil.optimisation.functions.TotalVariation`
    """

    def __init__(self, alpha=1, max_iteration=100, tolerance=0, isotropic=True,
                 nonnegativity=None, device='cpu', strong_convexity_constant=0):
        # Toolkit flag for the TV norm: 0 = isotropic (L2), 1 = anisotropic (L1).
        if isotropic:
            self.methodTV = 0
        else:
            self.methodTV = 1

        if nonnegativity is None:
            # Deprecate this warning in future versions and allow nonnegativity
            # to be default False in the init.
            warnings.warn('Note that the default behaviour now sets the nonnegativity constraint to False ',
                          UserWarning, stacklevel=2)
            nonnegativity = False

        # The toolkit expects an int flag: 1 enforces non-negativity, 0 does not.
        # (Bug fix: the original converted to 0/1 and then clobbered the result
        # with the raw boolean; keep the int conversion only.)
        if nonnegativity:
            self.nonnegativity = 1
        else:
            self.nonnegativity = 0

        self.alpha = alpha
        self.max_iteration = max_iteration
        self.tolerance = tolerance
        self.device = device
        super(FGP_TV, self).__init__(strong_convexity_constant=strong_convexity_constant)

    def _fista_on_dual_rof(self, in_arr, tau):
        r""" Implements the Fast Gradient Projection algorithm on the dual problem
        of the Total Variation Denoising problem (ROF).

        Returns the denoised array and the info vector (iterations performed,
        reached tolerance) filled in by the toolkit.
        """
        info = np.zeros((2,), dtype=np.float32)
        res = regularisers.FGP_TV(in_arr,
                                  self.alpha * tau,
                                  self.max_iteration,
                                  self.tolerance,
                                  self.methodTV,
                                  self.nonnegativity,
                                  infovector=info,
                                  device=self.device)
        return res, info

    def proximal_numpy(self, in_arr, tau):
        # With a strongly convex term gamma, the proximal reduces to a plain TV
        # proximal on data and step size divided by (1 + gamma*tau).
        if self.strong_convexity_constant > 0:
            strongly_convex_factor = (1 + tau * self.strong_convexity_constant)
            in_arr /= strongly_convex_factor
            tau /= strongly_convex_factor

        solution = self._fista_on_dual_rof(in_arr, tau)

        if self.strong_convexity_constant > 0:
            # Restore the caller's array and step size (scaled in place above).
            in_arr *= strongly_convex_factor
            tau *= strongly_convex_factor

        return solution

    def __rmul__(self, scalar):
        '''Define the multiplication with a scalar; this changes the regularisation
        parameter in the plugin.'''
        if not isinstance(scalar, Number):
            # Bug fix: `raise NotImplemented` raised a TypeError itself because
            # NotImplemented is not an exception. Returning the sentinel lets
            # Python produce the standard TypeError for unsupported operands.
            return NotImplemented
        self.alpha *= scalar
        return self

    def check_input(self, input):
        # The CCPi toolkit supports at most 3D data.
        if len(input.shape) > 3:
            raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length))
def __init__(self, alpha=1, gamma=1, max_iteration=100, tolerance=0, device='cpu', **kwargs):
    '''Creator of Total Generalised Variation Function

    :param alpha: regularisation parameter
    :type alpha: number, default 1
    :param gamma: ratio of TGV terms
    :type gamma: number, default 1, can range between 1 and 2
    :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached
    :type max_iteration: integer, default 100
    :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion.
    :type tolerance: float, default 0
    :param device: determines if the code runs on CPU or GPU
    :type device: string, default 'cpu', can be 'gpu' if GPU is installed
    '''
    self.alpha = alpha
    self.gamma = gamma  # validated by the `gamma` property setter
    self.max_iteration = max_iteration
    self.tolerance = tolerance
    self.device = device

    if kwargs.get('iter_TGV', None) is not None:
        # 'iter_TGV' was superseded by 'num_iter'; warn instead of accepting it
        # silently (the original left only a commented-out error here).
        warnings.warn("'iter_TGV' parameter has been superseded by 'num_iter'. Use that instead.",
                      DeprecationWarning, stacklevel=2)
        self.num_iter = kwargs.get('iter_TGV')
def __call__(self, x):
    '''Function value is not implemented for this plugin; warns and returns NaN.'''
    message = "{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__)
    warnings.warn(message)
    return np.nan
@property
def gamma(self):
    # Ratio of the two TGV terms; constrained to the documented range [1, 2].
    return self.__gamma

@gamma.setter
def gamma(self, value):
    # Bug fix: out-of-range values used to be silently ignored, which left
    # `__gamma` unset on first assignment and produced a cryptic AttributeError
    # later. Reject invalid values explicitly instead.
    if 1 <= value <= 2:
        self.__gamma = value
    else:
        raise ValueError('gamma must be between 1 and 2. Got {}'.format(value))

@property
def alpha2(self):
    # Second TGV weight, defined relative to alpha1 via the gamma ratio.
    return self.alpha1 * self.gamma

@property
def alpha1(self):
    # First TGV weight, fixed to 1; the overall scale comes from `alpha`.
    return 1.

def proximal_numpy(self, in_arr, tau):
    '''Evaluate the TGV proximal via the CCPi toolkit; returns (result, info).

    info holds the number of iterations and the reached tolerance, see
    https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/master/src/Core/regularisers_CPU/TGV_core.c#L168
    Stopping Criteria  || u^k - u^(k-1) ||_{2} / || u^{k} ||_{2}
    '''
    info = np.zeros((2,), dtype=np.float32)
    res = regularisers.TGV(in_arr,
                           self.alpha * tau,
                           self.alpha1,
                           self.alpha2,
                           self.max_iteration,
                           self.LipshitzConstant,
                           self.tolerance,
                           infovector=info,
                           device=self.device)
    return res, info
def convex_conjugate(self, x):
    '''Convex conjugate is not implemented for this plugin; warns and returns NaN.'''
    message = "{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__)
    warnings.warn(message)
    return np.nan
def __rmul__(self, scalar):
    '''Define the multiplication with a scalar; this changes the regularisation
    parameter in the plugin.'''
    if not isinstance(scalar, Number):
        # Bug fix: `raise NotImplemented` raised a TypeError itself because
        # NotImplemented is not an exception class. Returning the sentinel
        # lets Python produce the standard TypeError for unsupported operands.
        return NotImplemented
    self.alpha *= scalar
    return self
def check_input(self, input):
    '''Validate dimensionality and set the Lipschitz constant used by the TGV solver.'''
    ndim = len(input.shape)
    if ndim == 2:
        self.LipshitzConstant = 12
    elif ndim == 3:
        self.LipshitzConstant = 16  # Vaggelis to confirm
    else:
        raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length))
class FGP_dTV(RegulariserFunction):
    '''Creator of FGP_dTV Function

    :param reference: reference image
    :type reference: ImageData
    :param alpha: regularisation parameter
    :type alpha: number, default 1
    :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached
    :type max_iteration: integer, default 100
    :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion.
    :type tolerance: float, default 0
    :param eta: smoothing constant to calculate gradient of the reference
    :type eta: number, default 0.01
    :param isotropic: Whether it uses L2 (isotropic) or L1 (anisotropic) norm
    :type isotropic: boolean, default True
    :param nonnegativity: Whether to add the non-negativity constraint
    :type nonnegativity: boolean, default True
    :param device: determines if the code runs on CPU or GPU
    :type device: string, default 'cpu', can be 'gpu' if GPU is installed
    '''

    def __init__(self, reference, alpha=1, max_iteration=100, tolerance=0,
                 eta=0.01, isotropic=True, nonnegativity=True, device='cpu'):
        # Toolkit flag for the TV norm: 0 = isotropic (L2), 1 = anisotropic (L1).
        if isotropic:
            self.methodTV = 0
        else:
            self.methodTV = 1

        # Toolkit flag for the non-negativity constraint: 1 = enforce, 0 = do not.
        if nonnegativity:
            self.nonnegativity = 1
        else:
            self.nonnegativity = 0

        self.alpha = alpha
        self.max_iteration = max_iteration
        self.tolerance = tolerance
        self.device = device  # string for 'cpu' or 'gpu'
        # Keep the reference image as a contiguous float32 array for the toolkit.
        self.reference = np.asarray(reference.as_array(), dtype=np.float32)
        self.eta = eta

    def __call__(self, x):
        '''Function value is not implemented for this plugin; warns and returns NaN.'''
        warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan

    def convex_conjugate(self, x):
        '''Convex conjugate is not implemented for this plugin; warns and returns NaN.'''
        warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
        return np.nan

    def __rmul__(self, scalar):
        '''Define the multiplication with a scalar; this changes the regularisation
        parameter in the plugin.'''
        if not isinstance(scalar, Number):
            # Bug fix: `raise NotImplemented` raised a TypeError itself because
            # NotImplemented is not an exception. Returning the sentinel lets
            # Python produce the standard TypeError for unsupported operands.
            return NotImplemented
        self.alpha *= scalar
        return self

    def check_input(self, input):
        # The CCPi toolkit supports at most 3D data.
        if len(input.shape) > 3:
            raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length))
def __init__(self, alpha=1, max_iteration=100, tolerance=0):
    '''Creator of TNV Function

    :param alpha: regularisation parameter
    :type alpha: number, default 1
    :param max_iteration: max number of sub iterations. The algorithm will iterate up to this number of iteration or up to when the tolerance has been reached
    :type max_iteration: integer, default 100
    :param tolerance: minimum difference between previous iteration of the algorithm that determines the stop of the iteration earlier than max_iteration. If set to 0 only the max_iteration will be used as stop criterion.
    :type tolerance: float, default 0
    '''
    # Record the solver configuration for later use by proximal_numpy.
    self.tolerance = tolerance
    self.max_iteration = max_iteration
    self.alpha = alpha
def __call__(self, x):
    '''Function value is not implemented for this plugin; warns and returns NaN.'''
    message = "{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__)
    warnings.warn(message)
    return np.nan
def proximal_numpy(self, in_arr, tau):
    '''Evaluate the TNV proximal via the CCPi toolkit; returns (result, info).'''
    # remove any dimension of size 1
    squeezed = np.squeeze(in_arr)
    result = regularisers.TNV(squeezed,
                              self.alpha * tau,
                              self.max_iteration,
                              self.tolerance)
    # The TNV wrapper exposes no info vector; return an empty list in its place.
    return result, []
def convex_conjugate(self, x):
    '''Convex conjugate is not implemented for this plugin; warns and returns NaN.'''
    text = "{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__)
    warnings.warn(text)
    return np.nan
def __rmul__(self, scalar):
    '''Define the multiplication with a scalar; this changes the regularisation
    parameter in the plugin.'''
    if not isinstance(scalar, Number):
        # Bug fix: `raise NotImplemented` raised a TypeError itself because
        # NotImplemented is not an exception class. Returning the sentinel
        # lets Python produce the standard TypeError for unsupported operands.
        return NotImplemented
    self.alpha *= scalar
    return self
def check_input(self, input):
    '''TNV requires 2D+channel data with the first dimension as the channel dimension'''
    if not isinstance(input, DataContainer):
        # if it is not a CIL DataContainer we assume that the data is passed
        # in the correct order; discard any dimension of size 1
        effective_ndim = sum(1 for extent in input.shape if extent != 1)
        if effective_ndim != 3:
            raise ValueError('TNV requires 3D data (with channel as first axis). Got {}'.format(input.shape))
        return
    ImageDimension.check_order_for_engine('cil', input.geometry)
    geometry = input.geometry
    if geometry.channels == 1 or geometry.ndim != 3:
        raise ValueError('TNV requires 2D+channel data. Got {}'.format(geometry.dimension_labels))