Source code for cil.optimisation.functions.MixedL21Norm
# Copyright 2019 United Kingdom Research and Innovation
# Copyright 2019 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

from cil.optimisation.functions import Function
from cil.framework import BlockDataContainer
import numpy as np
from numbers import Number

has_numba = True
try:
    import numba

    @numba.jit(parallel=True, nopython=True)
    def _proximal_step_numba(arr, abstau):
        '''Numba implementation of a step in the calculation of the proximal of MixedL21Norm

        Replaces every entry ``n`` of ``arr`` with ``max(n - abstau, 0) / n``.
        Entries equal to 0 are left untouched (convention 0 * (0/0) = 0).
        Subtracting instead of dividing by ``abstau`` keeps ``abstau == 0``
        well defined (it yields 1, i.e. the identity proximal).

        Parameters:
        -----------
        arr : numpy array. It must be C-contiguous: ``ravel`` returns a copy
              otherwise and the in-place result would be lost.
        abstau: float >= 0

        Returns:
        --------
        0 on success. The output is stored in the input array.
        '''
        flat = arr.ravel()
        for i in numba.prange(flat.size):
            norm = flat[i]
            if norm == 0:
                continue
            shrunk = norm - abstau
            if shrunk <= 0.0:
                shrunk = 0.
            flat[i] = shrunk / norm
        return 0
except ImportError:
    has_numba = False


def _proximal_step_numpy(arr, tau):
    '''Numpy implementation of a step in the calculation of the proximal of MixedL21Norm

    Computes ``max(arr - |tau|, 0) / arr`` element-wise, with 0 wherever
    ``arr == 0`` (convention 0 * (0/0) = 0). ``arr`` is not modified.

    Parameters:
    -----------
    arr : DataContainer, best if contiguous memory.
    tau: float, numpy array or DataContainer

    Returns:
    --------
    A new DataContainer holding the scaling factors, with nan replaced by 0.
    '''
    # Scalars and numpy arrays go through the ufunc. Objects numpy cannot
    # cast to float32 raise a TypeError subclass (e.g. UFuncTypeError) and
    # fall back to their own abs().
    try:
        abstau = np.abs(tau, dtype=np.float32)
    except TypeError:
        abstau = tau.abs()

    # Subtract rather than divide by tau, so tau == 0 gives the identity
    # instead of inf/inf, and arr never has to be modified and restored.
    res = arr - abstau
    res.maximum(0.0, out=res)
    res /= arr

    # Entries with arr == 0 produced 0/0 = nan; apply the 0 * (0/0) = 0 convention.
    resarray = res.as_array()
    resarray[np.isnan(resarray)] = 0
    res.fill(resarray)
    return res
class MixedL21Norm(Function):

    r""" MixedL21Norm function: :math:`F(x) = ||x||_{2,1} = \sum |x|_{2} = \sum \sqrt{ (x^{1})^{2} + (x^{2})^{2} + \dots}`

        where x is a BlockDataContainer, i.e., :math:`x=(x^{1}, x^{2}, \dots)`

    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted but ignored.
        super(MixedL21Norm, self).__init__()

    def __call__(self, x):

        r"""Returns the value of the MixedL21Norm function at x.

        :param x: :code:`BlockDataContainer`
        :raises ValueError: if x is not a BlockDataContainer
        """
        if not isinstance(x, BlockDataContainer):
            raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))

        return x.pnorm(p=2).sum()

    def convex_conjugate(self, x):

        r"""Returns the value of the convex conjugate of the MixedL21Norm function at x.

        This is the Indicator function of :math:`\mathbb{I}_{\{\|\cdot\|_{2,\infty}\leq1\}}(x^{*})`,

        i.e.,

        .. math:: \mathbb{I}_{\{\|\cdot\|_{2, \infty}\leq1\}}(x^{*})
            = \begin{cases}
            0, \mbox{if } \|x\|_{2, \infty}\leq1\\
            \infty, \mbox{otherwise}
            \end{cases}

        where,

        .. math:: \|x\|_{2,\infty} = \max\{ \|x\|_{2} \} = \max\{ \sqrt{ (x^{1})^{2} + (x^{2})^{2} + \dots}\}

        The bound is checked with an absolute tolerance of 1e-5.

        :param x: :code:`BlockDataContainer`
        :returns: 0 or :code:`numpy.inf`
        :raises ValueError: if x is not a BlockDataContainer
        """
        if not isinstance(x, BlockDataContainer):
            raise ValueError('convex_conjugate expected BlockDataContainer, got {}'.format(type(x)))

        # Tolerance absorbs round-off, e.g. right after a projection onto the unit ball.
        excess = x.pnorm(2).max() - 1
        if excess <= 1e-5:
            return 0
        return np.inf

    def proximal(self, x, tau, out=None):

        r"""Returns the value of the proximal operator of the MixedL21Norm function at x.

        .. math :: \mathrm{prox}_{\tau F}(x) = \frac{x}{\|x\|_{2}}\max\{ \|x\|_{2} - \tau, 0 \}

        where the convention 0 · (0/0) = 0 is used.

        :param x: :code:`BlockDataContainer`
        :param tau: step size; a number, or (numpy path only) an array/DataContainer
        :param out: optional container to write the result into
        :returns: the proximal, as returned by :code:`x.multiply(..., out=out)`
        """
        tmp = x.pnorm(2)
        if has_numba and isinstance(tau, Number):
            try:
                # may involve a copy if the data is not contiguous
                tmparr = np.asarray(tmp.as_array(), order='C', dtype=tmp.dtype)
                if _proximal_step_numba(tmparr, np.abs(tau)) != 0:
                    # if numba silently crashes
                    raise RuntimeError('MixedL21Norm.proximal: numba silently crashed.')
                res = tmp
                res.fill(tmparr)
            except Exception:
                # np.asarray does not copy an already C-contiguous array of
                # matching dtype, so tmparr may share memory with tmp and be
                # partially overwritten. Recompute the norm before falling back.
                res = _proximal_step_numpy(x.pnorm(2), tau)
        else:
            res = _proximal_step_numpy(tmp, tau)

        return x.multiply(res, out=out)
class SmoothMixedL21Norm(Function):

    r""" SmoothMixedL21Norm function: :math:`F(x) = \sum \sqrt{ (x^{1})^{2} + (x^{2})^{2} + \dots + \epsilon^{2}}`

        where x is a BlockDataContainer, i.e., :math:`x=(x^{1}, x^{2}, \dots)`

        Conjugate, proximal and proximal conjugate methods have no closed-form solution.

        The gradient is Lipschitz continuous with constant :math:`L = 1/|\epsilon|`.
    """

    def __init__(self, epsilon):

        r'''
        :param epsilon: smoothing parameter making MixedL21Norm differentiable; must be non-zero
        :raises ValueError: if epsilon is 0
        '''
        # Validate first: the Lipschitz constant below divides by epsilon.
        if epsilon == 0:
            raise ValueError('We need epsilon>0. Otherwise, call "MixedL21Norm" ')
        # Per pixel, the Hessian of sqrt(|x|^2 + eps^2) has largest eigenvalue
        # 1/sqrt(|x|^2 + eps^2), which is at most 1/|eps| (attained at x = 0).
        super(SmoothMixedL21Norm, self).__init__(L=1.0 / abs(epsilon))
        self.epsilon = epsilon

    def __call__(self, x):

        """Returns the value of the SmoothMixedL21Norm function at x.

        :param x: :code:`BlockDataContainer`
        :raises ValueError: if x is not a BlockDataContainer
        """
        if not isinstance(x, BlockDataContainer):
            raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))

        return (x.pnorm(2).power(2) + self.epsilon**2).sqrt().sum()

    def gradient(self, x, out=None):

        r"""Returns the value of the gradient of the SmoothMixedL21Norm function at x.

        .. math:: \nabla F(x) = \frac{x}{\sqrt{\|x\|_{2}^{2} + \epsilon^{2}}}

        :param x: :code:`BlockDataContainer`
        :param out: optional container to write the result into
        :raises ValueError: if x is not a BlockDataContainer
        """
        if not isinstance(x, BlockDataContainer):
            raise ValueError('gradient expected BlockDataContainer, got {}'.format(type(x)))

        denom = (x.pnorm(2).power(2) + self.epsilon**2).sqrt()
        return x.divide(denom, out=out)