# -*- coding: utf-8 -*-
# Copyright 2019 United Kingdom Research and Innovation
# Copyright 2019 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
import numpy
import functools
from cil.framework import ImageData, BlockDataContainer, DataContainer
from cil.optimisation.operators import Operator, LinearOperator
from cil.framework import BlockGeometry
try:
from sirf import SIRF
from sirf.SIRF import DataContainer as SIRFDataContainer
has_sirf = True
except ImportError as ie:
has_sirf = False
class BlockOperator(Operator):
    r'''A Block matrix containing Operators

    The Block Framework is a generic strategy to treat variational problems in the
    following form:

    .. math::

        \min Regulariser + Fidelity

    BlockOperators have a generic shape M x N, and when applied on an
    Nx1 BlockDataContainer, will yield an Mx1 BlockDataContainer.

    Notice: BlockDataContainer are only allowed to have the shape of N x 1, with
    N rows and 1 column.

    User may specify the shape of the block; by default it is a column vector,
    i.e. shape ``(len(args), 1)``.

    Operators in a Block are required to have the same domain column-wise and the
    same range row-wise.
    '''

    __array_priority__ = 1

    def __init__(self, *args, **kwargs):
        '''Class creator

        Args:
            :param: vararg (Operator): Operators in the block.
            :param: shape (:obj:`tuple`, optional): If shape is passed the Operators in
                vararg are considered input in a row-by-row fashion.
                Shape and number of Operators must match.

        Raises:
            ValueError: if the number of Operators differs from the product of ``shape``.

        Example:
            BlockOperator(op0,op1) results in a column block, shape (2,1)
            BlockOperator(op0,op1,shape=(1,2)) results in a row block
        '''
        self.operators = args
        shape = kwargs.get('shape', None)
        if shape is None:
            # default: stack the operators vertically, one per row
            shape = (len(args), 1)
        self.shape = shape
        n_elements = functools.reduce(lambda x, y: x * y, shape, 1)
        if len(args) != n_elements:
            raise ValueError(
                'Dimension and size do not match: expected {} got {}'
                .format(n_elements, len(args)))
        # TODO
        # until a decent way to check equality of Acquisition/Image geometries
        # required to fulfil "Operators in a Block are required to have the same
        # domain column-wise and the same range row-wise."
        # let us just not check if column/row-wise compatible, which is actually
        # the same achieved by the column_wise_compatible and row_wise_compatible methods.
        # # test if operators are compatible
        # if not self.column_wise_compatible():
        #     raise ValueError('Operators in each column must have the same domain')
        # if not self.row_wise_compatible():
        #     raise ValueError('Operators in each row must have the same range')

    def column_wise_compatible(self):
        '''Operators in a Block should have the same domain per column

        Geometries that both expose a ``handle`` attribute (presumably SIRF
        geometries) cannot be compared here and are treated as compatible; any
        other pair is compared attribute-by-attribute through ``__dict__``.

        Returns:
            True if every column has matching domain geometries.
        '''
        rows, cols = self.shape
        compatible = True
        for col in range(cols):
            column_compatible = True
            for row in range(1, rows):
                dg0 = self.get_item(row - 1, col).domain_geometry()
                dg1 = self.get_item(row, col).domain_geometry()
                if not (hasattr(dg0, 'handle') and hasattr(dg1, 'handle')):
                    column_compatible = dg0.__dict__ == dg1.__dict__ and column_compatible
            compatible = compatible and column_compatible
        return compatible

    def row_wise_compatible(self):
        '''Operators in a Block should have the same range per row

        Geometries that both expose a ``handle`` attribute (presumably SIRF
        geometries) cannot be compared here and are treated as compatible; any
        other pair is compared attribute-by-attribute through ``__dict__``.

        Returns:
            True if every row has matching range geometries.
        '''
        rows, cols = self.shape
        compatible = True
        for row in range(rows):
            row_compatible = True
            for col in range(1, cols):
                rg0 = self.get_item(row, col - 1).range_geometry()
                rg1 = self.get_item(row, col).range_geometry()
                # previously referenced an undefined ``column_compatible`` here
                if not (hasattr(rg0, 'handle') and hasattr(rg1, 'handle')):
                    row_compatible = rg0.__dict__ == rg1.__dict__ and row_compatible
            compatible = compatible and row_compatible
        return compatible

    def get_item(self, row, col):
        '''returns the Operator at specified row and col

        Raises:
            ValueError: if row or col lies outside the block shape.
        '''
        # valid indices are 0..shape-1; an index equal to shape (or negative)
        # would otherwise map onto a different operator of the flat list
        if not 0 <= row < self.shape[0]:
            raise ValueError(
                'Requested row {} is out of range for shape {}'.format(row, self.shape))
        if not 0 <= col < self.shape[1]:
            raise ValueError(
                'Requested col {} is out of range for shape {}'.format(col, self.shape))
        # operators are stored row-by-row
        index = row * self.shape[1] + col
        return self.operators[index]

    def norm(self, **kwargs):
        '''Returns the norm of the BlockOperator

        The value is the square root of the sum of the squared norms of the
        contained operators, which is an upper bound of the operator norm.
        If an operator in the block does not have a ``norm`` method (i.e. SIRF
        AcquisitionModels) PowerMethod is used if the operator is linear,
        otherwise an Error is raised.

        Raises:
            TypeError: if an operator has no ``norm`` method and is not linear.
        '''
        squared_norms = []
        for op in self.operators:
            if hasattr(op, 'norm'):
                squared_norms.append(op.norm(**kwargs) ** 2.)
            elif op.is_linear():
                # NOTE(review): assumes PowerMethod(...)[0] is the operator norm
                # (not its square), as in CIL's LinearOperator.PowerMethod; it is
                # squared so both branches contribute squared norms to the sum.
                squared_norms.append(LinearOperator.PowerMethod(op, 20)[0] ** 2.)
            else:
                raise TypeError('Operator {} does not have a norm method and is not linear'.format(op))
        return numpy.sqrt(sum(squared_norms))

    def direct(self, x, out=None):
        '''Direct operation for the BlockOperator

        BlockOperator work on BlockDataContainer, but they will work on DataContainers
        and inherited classes by simple wrapping the input in a BlockDataContainer of shape (1,1)

        Args:
            x: input BlockDataContainer (a DataContainer is wrapped automatically).
            out: optional container in the range of the operator, filled in place.

        Returns:
            a new BlockDataContainer when ``out`` is None, otherwise None.

        Raises:
            ValueError: if the shape of ``x`` is incompatible with the block.
        '''
        if not isinstance(x, BlockDataContainer):
            x_b = BlockDataContainer(x)
        else:
            x_b = x
        # also validates that x_b matches the number of block columns
        shape = self.get_output_shape(x_b.shape)
        if out is None:
            res = []
            for row in range(self.shape[0]):
                for col in range(self.shape[1]):
                    if col == 0:
                        prod = self.get_item(row, col).direct(x_b.get_item(col))
                    else:
                        prod += self.get_item(row, col).direct(x_b.get_item(col))
                res.append(prod)
            return BlockDataContainer(*res, shape=shape)

        # a temporary is only needed to accumulate the contributions of the
        # second column onwards, so single-column blocks skip the allocation
        tmp = self.range_geometry().allocate() if self.shape[1] > 1 else None
        for row in range(self.shape[0]):
            for col in range(self.shape[1]):
                if col == 0:
                    self.get_item(row, col).direct(
                        x_b.get_item(col),
                        out=out.get_item(row))
                else:
                    self.get_item(row, col).direct(
                        x_b.get_item(col),
                        out=tmp.get_item(row))
                    # bind to a name so the in-place add updates out's element
                    a = out.get_item(row)
                    a += tmp.get_item(row)

    def adjoint(self, x, out=None):
        '''Adjoint operation for the BlockOperator

        BlockOperator may contain both LinearOperator and Operator
        This method exists in BlockOperator as it is not known what type of
        Operator it will contain.

        BlockOperator work on BlockDataContainer, but they will work on DataContainers
        and inherited classes by simple wrapping the input in a BlockDataContainer of shape (1,1)

        Returns:
            when ``out`` is None, a single container for single-column blocks,
            otherwise a BlockDataContainer; None when ``out`` is given.

        Raises: ValueError if the contained Operators are not linear
        '''
        if not self.is_linear():
            raise ValueError('Not all operators in Block are linear.')
        if not isinstance(x, BlockDataContainer):
            x_b = BlockDataContainer(x)
        else:
            x_b = x
        # also validates that x_b matches the number of block rows
        shape = self.get_output_shape(x_b.shape, adjoint=True)
        if out is None:
            res = []
            for col in range(self.shape[1]):
                for row in range(self.shape[0]):
                    if row == 0:
                        prod = self.get_item(row, col).adjoint(x_b.get_item(row))
                    else:
                        prod += self.get_item(row, col).adjoint(x_b.get_item(row))
                res.append(prod)
            if self.shape[1] == 1:
                # the output is a single DataContainer, so we can take it out
                return res[0]
            return BlockDataContainer(*res, shape=shape)

        # ``out`` may be a single (CIL or SIRF) DataContainer rather than a
        # BlockDataContainer, e.g. for single-column blocks
        out_is_single = issubclass(out.__class__, DataContainer) or \
            (has_sirf and issubclass(out.__class__, SIRFDataContainer))
        for col in range(self.shape[1]):
            for row in range(self.shape[0]):
                op = self.get_item(row, col)
                if row == 0:
                    target = out if out_is_single else out.get_item(col)
                    op.adjoint(x_b.get_item(row), out=target)
                elif out_is_single:
                    out += op.adjoint(x_b.get_item(row))
                else:
                    # bind to a name so the in-place add updates out's element
                    a = out.get_item(col)
                    a += op.adjoint(x_b.get_item(row))

    def is_linear(self):
        '''returns whether all the elements of the BlockOperator are linear'''
        return all(op.is_linear() for op in self.operators)

    def get_output_shape(self, xshape, adjoint=False):
        '''returns the shape of the output BlockDataContainer

        A(N,M) direct u(M,1) -> N,1

        A(N,M)^T adjoint u(N,1) -> M,1

        Raises:
            ValueError: if ``xshape`` has more than one column or is incompatible.
        '''
        rows, cols = self.shape
        xrows, xcols = xshape
        if xcols != 1:
            raise ValueError('BlockDataContainer cannot have more than 1 column')
        if adjoint:
            if rows != xrows:
                raise ValueError('Incompatible shapes {} {}'.format(self.shape, xshape))
            return (cols, xcols)
        if cols != xrows:
            raise ValueError('Incompatible shapes {} {}'.format((rows, cols), xshape))
        return (rows, xcols)

    def __rmul__(self, scalar):
        '''Defines the left multiplication with a scalar

        :parameter scalar: (number or iterable containing numbers): one scalar
            for all operators, or one scalar per operator (row-by-row order).

        Returns: a block operator with Scaled Operators inside

        Raises:
            ValueError: if an iterable of scalars does not match the number of operators.
        '''
        if isinstance(scalar, (list, tuple, numpy.ndarray)):
            if len(scalar) != len(self.operators):
                raise ValueError('dimensions of scalars and operators do not match')
            scalars = scalar
        else:
            scalars = [scalar] * len(self.operators)
        # create a list of ScaledOperator-s
        ops = [v * op for v, op in zip(scalars, self.operators)]
        return type(self)(*ops, shape=self.shape)

    @property
    def T(self):
        '''Return the transposed of self

        The result has shape (cols, rows) and holds at position (i, j) the
        operator found at (j, i) in self. The individual operators themselves
        are not adjointed.
        '''
        rows, cols = self.shape
        # the constructor expects operators row-by-row for the new shape, so
        # iterate over new rows (old columns) in the outer loop
        oplist = [self.get_item(j, i) for i in range(cols) for j in range(rows)]
        return type(self)(*oplist, shape=(cols, rows))

    def domain_geometry(self):
        '''returns the domain of the BlockOperator

        If the shape of the BlockOperator is (N,1) the domain is a ImageGeometry or AcquisitionGeometry.
        Otherwise it is a BlockGeometry.
        '''
        if self.shape[1] == 1:
            # column BlockOperator
            return self.get_item(0, 0).domain_geometry()
        # get the geometries column wise from the first row only; this assumes
        # the block is column-wise compatible (not verified in __init__)
        return BlockGeometry(
            *[self.get_item(0, i).domain_geometry() for i in range(self.shape[1])])

    def range_geometry(self):
        '''returns the range of the BlockOperator

        The range geometries are taken from the first column only; this assumes
        the block is row-wise compatible (not verified in __init__).
        '''
        return BlockGeometry(
            *[self.get_item(i, 0).range_geometry() for i in range(self.shape[0])])

    def sum_abs_row(self):
        '''Combine the ``sum_abs_row`` of the contained operators.

        For each row, the ``sum_abs_row`` results of its operators are added.
        For a single-column block the row results are summed and wrapped in an
        ImageData; otherwise a BlockDataContainer of the row results is returned.
        '''
        res = []
        for row in range(self.shape[0]):
            for col in range(self.shape[1]):
                if col == 0:
                    prod = self.get_item(row, col).sum_abs_row()
                else:
                    prod += self.get_item(row, col).sum_abs_row()
            res.append(prod)
        if self.shape[1] == 1:
            tmp = sum(res)
            return ImageData(tmp)
        return BlockDataContainer(*res)

    def sum_abs_col(self):
        '''Combine the ``sum_abs_col`` of the contained operators.

        Returns a BlockDataContainer with, for each row, the sum of the
        ``sum_abs_col`` results of that row's operators.
        NOTE(review): this accumulates per row, like sum_abs_row; confirm that
        a per-row accumulation is intended for a column sum.
        '''
        res = []
        for row in range(self.shape[0]):
            for col in range(self.shape[1]):
                if col == 0:
                    prod = self.get_item(row, col).sum_abs_col()
                else:
                    prod += self.get_item(row, col).sum_abs_col()
            res.append(prod)
        return BlockDataContainer(*res)

    def __len__(self):
        '''returns the total number of operators in the block'''
        return len(self.operators)

    def __getitem__(self, index):
        '''returns the index-th operator in the block irrespectively of its shape'''
        return self.operators[index]

    def get_as_list(self):
        '''returns the list of operators'''
        return self.operators