Source code for cil.optimisation.algorithms.PD3O

#  Copyright 2024 United Kingdom Research and Innovation
#  Copyright 2024 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt


from cil.optimisation.algorithms import Algorithm
from cil.optimisation.functions import ZeroFunction
import logging
import warnings
class PD3O(Algorithm):
    r"""Primal Dual Three Operator Splitting (PD3O) algorithm, see "A New
    Primal-Dual Algorithm for Minimizing the Sum of Three Functions with a
    Linear Operator".

    This is a primal dual algorithm for minimising :math:`f(x)+g(x)+h(Ax)`
    where all functions are proper, lower semi-continuous and convex,
    :math:`f` should be differentiable with a Lipschitz continuous gradient
    and :math:`A` is a bounded linear operator.

    Parameters
    ----------
    f : Function
        A smooth function with Lipschitz continuous gradient.
    g : Function
        A convex function with a computationally computable proximal.
    h : Function
        A composite convex function.
    operator : Operator
        Bounded linear operator.
    delta : float, optional, default is ``1./(gamma*operator.norm()**2)``
        The dual step-size.
    gamma : float, optional, default is ``2.0/f.L``
        The primal step size.
    initial : DataContainer, optional, default is a container of zeros in the
        domain of the operator
        Initial point for the algorithm.

    Reference
    ---------
    Yan, M. A New Primal-Dual Algorithm for Minimizing the Sum of Three
    Functions with a Linear Operator. J Sci Comput 76, 1698-1717 (2018).
    https://doi.org/10.1007/s10915-018-0680-3
    """

    def __init__(self, f, g, h, operator, delta=None, gamma=None, initial=None, **kwargs):
        # All state initialisation is delegated to set_up; kwargs are also
        # forwarded to the Algorithm base class (e.g. max_iteration).
        super(PD3O, self).__init__(**kwargs)
        self.set_up(f=f, g=g, h=h, operator=operator, delta=delta, gamma=gamma, initial=initial, **kwargs)
    def set_up(self, f, g, h, operator, delta=None, gamma=None, initial=None, **kwargs):
        """Initialise the PD3O state: functions, step sizes, primal/dual
        buffers, and perform the initial dual (proximal conjugate) step.

        Parameters are as documented on the class.
        """
        logging.info("{} setting up".format(self.__class__.__name__, ))
        self.f = f  # smooth function
        if isinstance(self.f, ZeroFunction):
            warnings.warn(" If self.f is the ZeroFunction, then PD3O = PDHG. Please use PDHG instead. Otherwise, select a relatively small parameter gamma ", UserWarning)
            # ZeroFunction has no usable Lipschitz constant, so the generic
            # default gamma = 0.99*2.0/f.L below cannot be used; fall back
            # to 1/||A|| instead.
            if gamma is None:
                gamma = 1.0/operator.norm()

        self.g = g  # proximable
        self.h = h  # composite
        self.operator = operator

        # Default primal step size: just under the 2/L gradient-descent limit.
        if gamma is None:
            gamma = 0.99*2.0/self.f.L
        # Default dual step size: just under 1/(gamma*||A||^2), i.e. the
        # PD3O step-size condition gamma*delta*||A||^2 <= 1 (Yan 2018).
        if delta is None :
            delta = 0.99/(gamma*self.operator.norm()**2)

        self.gamma = gamma
        self.delta = delta

        if initial is None:
            self.x = self.operator.domain_geometry().allocate(0)
        else:
            # copy so the caller's container is never mutated in place
            self.x = initial.copy()

        # Pre-allocated work buffers; update() swaps and overwrites these
        # in place rather than allocating per iteration.
        self.x_old = self.x.copy()
        self.s_old = self.operator.range_geometry().allocate(0)
        self.s = self.operator.range_geometry().allocate(0)
        self.grad_f = self.operator.domain_geometry().allocate(0)

        self.configured = True
        logging.info("{} configured".format(self.__class__.__name__, ))

        # initial proximal conjugate step:
        # s = prox_{delta h*}(s_old + delta * A x_old), with s_old = 0 here.
        # sapyb(a, y, b, out) computes a*self + b*y into out.
        self.operator.direct(self.x_old, out=self.s)
        self.s_old.sapyb(1, self.s, self.delta, out=self.s_old)
        self.h.proximal_conjugate(self.s_old, self.delta, out=self.s)
    def update(self):
        r""" Performs a single iteration of the PD3O algorithm """
        # Following equations 4 in https://link.springer.com/article/10.1007/s10915-018-0680-3
        # in this case order of proximal steps we recover the (primal) PDHG, when f=0

        # Swap x and x_old without allocating: x_old now holds the current
        # iterate; the old x container is recycled as scratch/output space.
        tmp = self.x_old
        self.x_old = self.x
        self.x = tmp

        # proximal step:
        # x = prox_{gamma g}( x_old - gamma*grad_f(x_old) - gamma*A^T s )
        # sapyb(a, y, b, out) computes a*self + b*y into out.
        self.f.gradient(self.x_old, out=self.grad_f)
        self.x_old.sapyb(1., self.grad_f, -self.gamma, out = self.grad_f) # x_old - gamma * grad_f(x_old)
        self.operator.adjoint(self.s, out=self.x_old)                     # x_old buffer now holds A^T s
        self.x_old.sapyb(-self.gamma, self.grad_f, 1.0, out=self.x_old)
        self.g.proximal(self.x_old, self.gamma, out = self.x)

        # update step: build the extrapolated point used by the dual update,
        # accumulating gamma*grad_f(x) into the previously stored quantity.
        self.f.gradient(self.x, out=self.x_old)   # x_old buffer reused for grad_f(x)
        self.x_old *= self.gamma
        self.grad_f += self.x_old                 # grad_f = x_old - gamma*grad_f(x_old) + gamma*grad_f(x)
        self.x.sapyb(2, self.grad_f, -1.0, out=self.x_old) # 2*x - x_old + gamma*(grad_f_x_old) - gamma*(grad_f_x)

        # Swap s and s_old: s_old now holds the current dual variable and the
        # old s container becomes scratch for the dual update below.
        tmp = self.s_old
        self.s_old = self.s
        self.s = tmp

        # proximal conjugate step:
        # s = prox_{delta h*}( s_old + delta * A x_bar ), x_bar in x_old buffer
        self.operator.direct(self.x_old, out=self.s)
        self.s_old.sapyb(1, self.s, self.delta, out=self.s_old)
        self.h.proximal_conjugate(self.s_old, self.delta, out=self.s)
[docs] def update_objective(self): """ Evaluates the primal objective """ self.operator.direct(self.x, out=self.s_old) fun_h = self.h(self.s_old) fun_g = self.g(self.x) fun_f = self.f(self.x) p1 = fun_f + fun_g + fun_h self.loss.append(p1)