# Copyright 2019 United Kingdom Research and Innovation
# Copyright 2019 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
import numpy
import functools
from numbers import Number
from cil.framework import ImageData, BlockDataContainer, DataContainer
from cil.optimisation.operators import Operator, LinearOperator
from cil.framework import BlockGeometry
# SIRF is an optional dependency: record whether it is importable.
try:
    from sirf import SIRF
    from sirf.SIRF import DataContainer as SIRFDataContainer
    has_sirf = True
except ImportError:
    has_sirf = False
class BlockOperator(Operator):
    r'''A Block matrix containing Operators

    Parameters
    ----------
    *args : Operator
        Operators in the block, listed row by row.
    **kwargs : dict
        shape (:obj:`tuple`, optional): the ``(rows, cols)`` shape of the block.
        If shape is passed the Operators in vararg are considered input in a
        row-by-row fashion. Defaults to ``(len(args), 1)``, i.e. a column block.

    Note
    ----
    The Block Framework is a generic strategy to treat variational problems in the
    following form:

    .. math::

        \min Regulariser + Fidelity

    BlockOperators have a generic shape M x N, and when applied on an
    Nx1 BlockDataContainer, will yield an Mx1 BlockDataContainer.

    Note
    -----
    BlockDataContainers are only allowed to have the shape of N x 1, with
    N rows and 1 column.

    Users may specify the shape of the block; by default it is a column vector.

    Operators in a Block are required to have the same domain column-wise and the
    same range row-wise.

    Examples
    --------
    BlockOperator(op0,op1) results in a column block, of shape ``(2, 1)``

    BlockOperator(op0,op1,shape=(1,2)) results in a row block, of shape ``(1, 2)``
    '''

    __array_priority__ = 1

    def __init__(self, *args, **kwargs):
        '''Store the operators and the block shape.

        Raises
        ------
        ValueError
            If the number of operators differs from ``rows * cols`` of ``shape``.
        '''
        self.operators = args
        shape = kwargs.get('shape', None)
        if shape is None:
            # default: one operator per row, i.e. a column block
            shape = (len(args), 1)
        self.shape = shape
        n_elements = functools.reduce(lambda x, y: x*y, shape, 1)
        if len(args) != n_elements:
            raise ValueError(
                'Dimension and size do not match: expected {} got {}'
                .format(n_elements, len(args)))
        # An (M, N) block maps an (N, 1) BlockDataContainer to an (M, 1) one.
        self._range_block_shape = (shape[0], 1)
        self._domain_block_shape = (shape[1], 1)
        # TODO
        # Until there is a reliable way to check equality of Acquisition/Image
        # geometries, the requirement "Operators in a Block are required to have
        # the same domain column-wise and the same range row-wise" is NOT enforced
        # here. Callers can use column_wise_compatible / row_wise_compatible.

    @staticmethod
    def _geometries_compatible(geom0, geom1):
        '''Return whether two geometries may share a block row or column.

        Geometries that both expose a ``handle`` attribute (presumably SIRF
        objects) cannot be compared via ``__dict__`` and are assumed compatible.
        '''
        if hasattr(geom0, 'handle') and hasattr(geom1, 'handle'):
            return True
        return geom0.__dict__ == geom1.__dict__

    def column_wise_compatible(self):
        '''Operators in a Block should have the same domain per column.

        Returns
        -------
        bool
            True if, in every column, consecutive operators have compatible domains.
        '''
        rows, cols = self.shape
        compatible = True
        for col in range(cols):
            for row in range(1, rows):
                dg0 = self.get_item(row-1, col).domain_geometry()
                dg1 = self.get_item(row, col).domain_geometry()
                compatible = self._geometries_compatible(dg0, dg1) and compatible
        return compatible

    def row_wise_compatible(self):
        '''Operators in a Block should have the same range per row.

        Returns
        -------
        bool
            True if, in every row, consecutive operators have compatible ranges.
        '''
        rows, cols = self.shape
        compatible = True
        for row in range(rows):
            for col in range(1, cols):
                rg0 = self.get_item(row, col-1).range_geometry()
                rg1 = self.get_item(row, col).range_geometry()
                compatible = self._geometries_compatible(rg0, rg1) and compatible
        return compatible

    def get_item(self, row, col):
        '''Returns the Operator at specified row and col

        Parameters
        ----------
        row: `int`
            The row index required. Negative values count from the last row.
        col: `int`
            The column index required. Negative values count from the last column.

        Raises
        ------
        ValueError
            If ``row`` or ``col`` is outside the block shape.
        '''
        rows, cols = self.shape
        r = row + rows if row < 0 else row
        c = col + cols if col < 0 else col
        # Both indices must be strictly inside the block: an out-of-range col
        # would otherwise silently map onto an operator in another row.
        if not 0 <= r < rows:
            raise ValueError(
                'Requested row {} is out of range for a block with {} rows'.format(row, rows))
        if not 0 <= c < cols:
            raise ValueError(
                'Requested col {} is out of range for a block with {} columns'.format(col, cols))
        return self.operators[r * cols + c]

    def norm(self):
        '''Returns the Euclidean norm of the norms of the individual operators in the BlockOperators '''
        return numpy.sqrt(numpy.sum(numpy.array(self.get_norms_as_list())**2))

    def get_norms_as_list(self):
        '''Returns a list of the individual norms of the Operators in the BlockOperator
        '''
        return [op.norm() for op in self.operators]

    def set_norms(self, norms):
        '''Uses the set_norm() function in Operator to set the norms of the operators in the BlockOperator from a list of custom values.

        Parameters
        ------------
        norms: list
            A list of positive real values the same length as the number of operators in the BlockOperator.

        Raises
        ------
        ValueError
            If ``len(norms)`` differs from the number of operators.
        '''
        if len(norms) != self.size:
            raise ValueError(
                "The length of the list of norms should be equal to the number of operators in the BlockOperator")
        for op, value in zip(self.operators, norms):
            op.set_norm(value)

    def direct(self, x, out=None):
        '''Direct operation for the BlockOperator

        Parameters
        ----------
        x: BlockDataContainer
            The input BlockDataContainer to apply the BlockOperator on. Can be a DataContainer if the domain geometry permits.
        out: BlockDataContainer, optional
            The output BlockDataContainer to store the result of the operation. If not provided, a new BlockDataContainer is created. Can be a DataContainer if the range geometry permits.

        Returns
        -------
        BlockDataContainer or DataContainer
            The result. When the range block shape is (1,1) and ``out`` is not a
            BlockDataContainer, the single contained DataContainer is returned.

        Raises
        ------
        ValueError
            If ``x`` does not have the domain block shape, or ``out`` is a plain
            DataContainer while the range block shape is not (1,1).

        Note
        -----
        BlockOperators work on BlockDataContainers, but they will also work on DataContainers
        and inherited classes by simply wrapping the input in a BlockDataContainer of shape (1,1)
        '''
        if not isinstance(x, BlockDataContainer):
            x_b = BlockDataContainer(x)
        else:
            x_b = x
        if x_b.shape != self._domain_block_shape:
            raise ValueError(
                'We expect the input to be a block data container of shape {}'.format( self._domain_block_shape))
        unwrap_data_container_on_return = self._range_block_shape == (1,1)
        if out is None:
            # allocate the output blockdatacontainer of the correct shape
            res = BlockDataContainer(*[self.get_item(row, 0).range_geometry().allocate(None)
                                       for row in range(self.shape[0])], shape=self._range_block_shape)
        elif not isinstance(out, BlockDataContainer):
            # Handle datacontainers or sirf datacontainers
            if unwrap_data_container_on_return:
                res = BlockDataContainer(out)
            else:
                raise ValueError(
                    f'The range of this block operator is not compatible with the `out` that was passed. Expected `out` to be `None` or a `BlockDataContainer` of shape {self._range_block_shape}')
        else:
            res = out
            unwrap_data_container_on_return = False
        for row in range(self.shape[0]):
            for col in range(self.shape[1]):
                if col == 0:
                    # the first column writes directly into the output row
                    self.get_item(row, col).direct(x_b.get_item(col), out=res.get_item(row))
                else:
                    # temp_out_row points to the element in res that we are adding to
                    temp_out_row = res.get_item(row)
                    temp_out_row += self.get_item(row, col).direct(x_b.get_item(col))
        if unwrap_data_container_on_return:
            # Return the out as the user passed it in case the range shape is (1,1)
            return res.get_item(0)
        return res

    def adjoint(self, x, out=None):
        '''Adjoint operation for the BlockOperator

        Parameters
        ----------
        x: BlockDataContainer
            The input BlockDataContainer to apply the BlockOperator adjoint on. Can be a DataContainer if the range geometry permits.
        out: BlockDataContainer, optional
            The output BlockDataContainer to store the result of the operation. If not provided, a new BlockDataContainer is created. Can be a DataContainer if the domain geometry permits.

        Returns
        -------
        BlockDataContainer or DataContainer
            The result. When the domain block shape is (1,1) and ``out`` is not a
            BlockDataContainer, the single contained DataContainer is returned.

        Note
        -----
        BlockOperator may contain both LinearOperator and Operator
        This method exists in BlockOperator as it is not known what type of
        Operator it will contain.

        BlockOperators work on BlockDataContainers, but they will also work on DataContainers
        and inherited classes by simply wrapping the input in a BlockDataContainer of shape (1,1)

        Raises
        ------
        ValueError
            If the contained Operators are not linear, ``x`` does not have the range
            block shape, or ``out`` is a plain DataContainer while the domain block
            shape is not (1,1).
        '''
        if not self.is_linear():
            raise ValueError('Not all operators in Block are linear.')
        if not isinstance(x, BlockDataContainer):
            x_b = BlockDataContainer(x)
        else:
            x_b = x
        if x_b.shape != self._range_block_shape:
            raise ValueError(
                'We expect the input to be a block data container of shape {}'.format( self._range_block_shape))
        unwrap_data_container_on_return = self._domain_block_shape == (1,1)
        if out is None:
            # allocate the output blockdatacontainer of the correct shape
            res = BlockDataContainer(*[self.get_item(0, col).domain_geometry().allocate(0)
                                       for col in range(self.shape[1])], shape=self._domain_block_shape)
        elif not isinstance(out, BlockDataContainer):
            # Handle datacontainers or sirf datacontainers
            if unwrap_data_container_on_return:
                res = BlockDataContainer(out)
            else:
                raise ValueError(
                    f'The domain of this block operator is not compatible with the `out` that was passed. Expected `out` to be `None` or a `BlockDataContainer` of shape {self._domain_block_shape}')
        else:
            res = out
            unwrap_data_container_on_return = False
        for col in range(self.shape[1]):
            for row in range(self.shape[0]):
                if row == 0:
                    # the first row writes directly into the output column
                    self.get_item(row, col).adjoint(
                        x_b.get_item(row),
                        out=res.get_item(col))
                else:
                    # temp_out_col points to the element in res that we are updating
                    temp_out_col = res.get_item(col)
                    temp_out_col += self.get_item(row, col).adjoint(
                        x_b.get_item(row),
                    )
        if unwrap_data_container_on_return:
            # Return the out as the user passed it in case the domain shape is (1,1)
            return res.get_item(0)
        return res

    def is_linear(self):
        '''Returns whether all the elements of the BlockOperator are linear'''
        return all(op.is_linear() for op in self.operators)

    def get_output_shape(self, xshape, adjoint=False):
        '''Returns the shape of the output BlockDataContainer

        Parameters
        ----------
        xshape: `tuple`
            The ``(rows, cols)`` shape of the input BlockDataContainer.
        adjoint: `bool`
            If True, return the output shape of the adjoint operation.

        Raises
        ------
        ValueError
            If ``xshape`` has more than one column or is incompatible with the block.

        Examples
        --------
        A(N,M) direct u(M,1) -> N,1

        A(N,M)^T adjoint u(N,1) -> M,1
        '''
        rows, cols = self.shape
        xrows, xcols = xshape
        if xcols != 1:
            raise ValueError(
                'BlockDataContainer cannot have more than 1 column')
        if adjoint:
            if rows != xrows:
                raise ValueError(
                    'Incompatible shapes {} {}'.format(self.shape, xshape))
            return (cols, xcols)
        if cols != xrows:
            raise ValueError(
                'Incompatible shapes {} {}'.format((rows, cols), xshape))
        return (rows, xcols)

    def __rmul__(self, scalar):
        '''Defines the left multiplication with a scalar. Returns a block operator with Scaled Operators inside.

        Parameters
        ------------
        scalar: number or iterable containing numbers
            A list, tuple or numpy array supplies one scalar per operator (row-by-row order).

        Raises
        ------
        ValueError
            If an iterable of scalars does not match the number of operators.
        '''
        if isinstance(scalar, (list, tuple, numpy.ndarray)):
            if len(scalar) != len(self.operators):
                raise ValueError(
                    'dimensions of scalars and operators do not match')
            scalars = scalar
        else:
            scalars = [scalar for _ in self.operators]
        # create a list of ScaledOperator-s
        ops = [v * op for v, op in zip(scalars, self.operators)]
        return type(self)(*ops, shape=self.shape)

    @property
    def T(self):
        '''Returns the transposed of self.

        Only the block layout is transposed; the individual operators are not.
        Recall the input list is shaped in a row-by-row fashion, so row ``i`` of
        the result is column ``i`` of self.'''
        rows, cols = self.shape
        oplist = [self.get_item(row, col)
                  for col in range(cols)
                  for row in range(rows)]
        return type(self)(*oplist, shape=(cols, rows))

    def domain_geometry(self):
        '''Returns the domain of the BlockOperator

        If the shape of the BlockOperator is (N,1) the domain is a ImageGeometry or AcquisitionGeometry.
        Otherwise it is a BlockGeometry.
        '''
        if self.shape[1] == 1:
            # column BlockOperator
            return self.get_item(0, 0).domain_geometry()
        # Take the domains from the first row only: operators in the same column
        # are assumed (not checked in __init__) to share a domain.
        return BlockGeometry(*[self.get_item(0, col).domain_geometry()
                               for col in range(self.shape[1])])

    def range_geometry(self):
        '''Returns the range of the BlockOperator

        If the shape of the BlockOperator is (1,N) the range is the range of the
        first operator; otherwise it is a BlockGeometry built from the first column.
        '''
        tmp = [self.get_item(row, 0).range_geometry()
               for row in range(self.shape[0])]
        if self.shape[0] == 1:
            return tmp[0]
        return BlockGeometry(*tmp)

    def sum_abs_row(self):
        '''For each block row, sum ``sum_abs_row()`` of its operators across columns.

        Returns an ``ImageData`` of the sum over all block rows when the block has
        a single column, otherwise a BlockDataContainer with one entry per row.

        NOTE(review): for a column block the per-row results (which live in
        different range spaces) are summed together; confirm this is intended.
        '''
        res = []
        for row in range(self.shape[0]):
            for col in range(self.shape[1]):
                if col == 0:
                    prod = self.get_item(row, col).sum_abs_row()
                else:
                    prod += self.get_item(row, col).sum_abs_row()
            res.append(prod)
        if self.shape[1] == 1:
            tmp = sum(res)
            return ImageData(tmp)
        return BlockDataContainer(*res)

    def sum_abs_col(self):
        '''For each block row, sum ``sum_abs_col()`` of its operators across columns.

        Returns a BlockDataContainer with one entry per block row.

        NOTE(review): column sums would normally accumulate over rows per block
        column; this sums over columns per row instead. Confirm intended semantics.
        '''
        res = []
        for row in range(self.shape[0]):
            for col in range(self.shape[1]):
                if col == 0:
                    prod = self.get_item(row, col).sum_abs_col()
                else:
                    prod += self.get_item(row, col).sum_abs_col()
            res.append(prod)
        return BlockDataContainer(*res)

    def __len__(self):
        '''Number of operators in the block.'''
        return len(self.operators)

    @property
    def size(self):
        '''Number of operators in the block.'''
        return len(self.operators)

    def __getitem__(self, index):
        '''Returns the index-th operator in the block irrespective of its shape'''
        return self.operators[index]

    def get_as_list(self):
        '''Returns the list of operators'''
        return self.operators