# Copyright 2019 United Kingdom Research and Innovation# Copyright 2019 The University of Manchester## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.## Authors:# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txtimportnumpyfromnumbersimportNumberimportfunctoolsfromcil.utilities.multiprocessingimportNUM_THREADS
[docs]classBlockDataContainer(object):'''Class to hold DataContainers as column vector Provides basic algebra between BlockDataContainer's, DataContainer's and subclasses and Numbers 1) algebra between `BlockDataContainer`s will be element-wise, only if the shape of the 2 `BlockDataContainer`s is the same, otherwise it will fail 2) algebra between `BlockDataContainer`s and `list` or `numpy array` will work as long as the number of `rows` and element of the arrays match, independently on the fact that the `BlockDataContainer` could be nested 3) algebra between `BlockDataContainer` and one `DataContainer` is possible. It will require all the `DataContainers` in the block to be compatible with the `DataContainer` we want to operate with. 4) algebra between `BlockDataContainer` and a `Number` is possible and it will be done with each element of the `BlockDataContainer` even if nested A = [ [B,C] , D] A * 3 = [ 3 * [B,C] , 3* D] = [ [ 3*B, 3*C] , 3*D ] '''ADD='add'SUBTRACT='subtract'MULTIPLY='multiply'DIVIDE='divide'POWER='power'SAPYB='sapyb'MAXIMUM='maximum'MINIMUM='minimum'ABS='abs'SIGN='sign'SQRT='sqrt'CONJUGATE='conjugate'__array_priority__=1__container_priority__=2@propertydefdtype(self):returntuple(i.dtypeforiinself.containers)def__init__(self,*args,**kwargs):''''''self.containers=argsself.index=0self.geometry=None#if len(set([i.shape for i in self.containers])):# self.geometry = self.containers[0].geometryshape=kwargs.get('shape',None)ifshapeisNone:shape=(len(args),1)# shape = (len(args),1)self.shape=shapen_elements=functools.reduce(lambdax,y:x*y,shape,1)iflen(args)!=n_elements:raiseValueError('Dimension and size do not match: expected {} got {}'.format(n_elements,len(args)))
[docs]def__iter__(self):'''BlockDataContainer is Iterable'''self.index=0returnself
[docs]defis_compatible(self,other):'''basic check if the size of the 2 objects fit'''ifisinstance(other,Number):returnTrueelifisinstance(other,(list,tuple,numpy.ndarray)):forotinother:ifnotisinstance(ot,Number):raiseValueError('List/ numpy array can only contain numbers {}'\
.format(type(ot)))returnlen(self.containers)==len(other)elifisinstance(other,BlockDataContainer):returnlen(self.containers)==len(other.containers)else:# this should work for other as DataContainers and childrenret=Truefori,elinenumerate(self.containers):ifisinstance(el,BlockDataContainer):a=el.is_compatible(other)else:a=el.shape==other.shaperet=retanda# probably will raisereturnret
defget_item(self,row):ifrow>self.shape[0]:raiseValueError('Requested row {} > max {}'.format(row,self.shape[0]))returnself.containers[row]def__getitem__(self,row):returnself.get_item(row)
[docs]defadd(self,other,*args,**kwargs):'''Algebra: add method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer :param: out (optional): provides a placehold for the resul. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.ADD,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.ADD,other,*args,**kwargs)
[docs]defsubtract(self,other,*args,**kwargs):'''Algebra: subtract method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer :param: out (optional): provides a placeholder for the result. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.SUBTRACT,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.SUBTRACT,other,*args,**kwargs)
[docs]defmultiply(self,other,*args,**kwargs):'''Algebra: multiply method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer) :param: out (optional): provides a placeholder for the result. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.MULTIPLY,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.MULTIPLY,other,*args,**kwargs)
[docs]defdivide(self,other,*args,**kwargs):'''Algebra: divide method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer) :param: out (optional): provides a placeholder for the result. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.DIVIDE,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.DIVIDE,other,*args,**kwargs)
[docs]defpower(self,other,*args,**kwargs):'''Algebra: power method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer :param: out (optional): provides a placeholder for the result. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.POWER,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.POWER,other,*args,**kwargs)
[docs]defmaximum(self,other,*args,**kwargs):'''Algebra: power method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer) :param: out (optional): provides a placeholder for the result. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.MAXIMUM,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.MAXIMUM,other,*args,**kwargs)
[docs]defminimum(self,other,*args,**kwargs):'''Algebra: power method of BlockDataContainer with number/DataContainer or BlockDataContainer :param: other (number, DataContainer or subclasses or BlockDataContainer) :param: out (optional): provides a placeholder for the result. '''out=kwargs.get('out',None)ifoutisnotNone:self.binary_operations(BlockDataContainer.MINIMUM,other,*args,**kwargs)else:returnself.binary_operations(BlockDataContainer.MINIMUM,other,*args,**kwargs)
[docs]defsapyb(self,a,y,b,out,num_threads=NUM_THREADS):r'''performs axpby element-wise on the BlockDataContainer containers Does the operation .. math:: a*x+b*y and stores the result in out, where x is self :param a: scalar :param b: scalar :param y: compatible (Block)DataContainer :param out: (Block)DataContainer to store the result Example: -------- >>> a = 2 >>> b = 3 >>> ig = ImageGeometry(10,11) >>> x = ig.allocate(1) >>> y = ig.allocate(2) >>> bdc1 = BlockDataContainer(2*x, y) >>> bdc2 = BlockDataContainer(x, 2*y) >>> out = bdc1.sapyb(a,bdc2,b) '''ifoutisNone:raiseValueError("out container cannot be None")kwargs={'a':a,'b':b,'out':out,'num_threads':NUM_THREADS}self.binary_operations(BlockDataContainer.SAPYB,y,**kwargs)
[docs]defaxpby(self,a,b,y,out,dtype=numpy.float32,num_threads=NUM_THREADS):'''Deprecated method. Alias of sapyb'''returnself.sapyb(a,y,b,out,num_threads)
[docs]defbinary_operations(self,operation,other,*args,**kwargs):'''Algebra: generic method of algebric operation with BlockDataContainer with number/DataContainer or BlockDataContainer Provides commutativity with DataContainer and subclasses, i.e. this class's reverse algebraic methods take precedence w.r.t. direct algebraic methods of DataContainer and subclasses. This method is not to be used directly '''ifnotself.is_compatible(other):raiseValueError('Incompatible for operation {}'.format(operation))out=kwargs.get('out',None)ifisinstance(other,Number):# try to do algebra with one DataContainer. Will raise error if not compatiblekw=kwargs.copy()res=[]fori,elinenumerate(self.containers):ifoperation==BlockDataContainer.ADD:op=el.addelifoperation==BlockDataContainer.SUBTRACT:op=el.subtractelifoperation==BlockDataContainer.MULTIPLY:op=el.multiplyelifoperation==BlockDataContainer.DIVIDE:op=el.divideelifoperation==BlockDataContainer.POWER:op=el.powerelifoperation==BlockDataContainer.MAXIMUM:op=el.maximumelifoperation==BlockDataContainer.MINIMUM:op=el.minimumelse:raiseValueError('Unsupported operation',operation)ifoutisnotNone:kw['out']=out.get_item(i)op(other,*args,**kw)else:res.append(op(other,*args,**kw))ifoutisnotNone:returnelse:returntype(self)(*res,shape=self.shape)elifisinstance(other,(list,tuple,numpy.ndarray,BlockDataContainer)):kw=kwargs.copy()res=[]ifisinstance(other,BlockDataContainer):the_other=other.containerselse:the_other=otherfori,zelinenumerate(zip(self.containers,the_other)):el=zel[0]ot=zel[1]ifoperation==BlockDataContainer.ADD:op=el.addelifoperation==BlockDataContainer.SUBTRACT:op=el.subtractelifoperation==BlockDataContainer.MULTIPLY:op=el.multiplyelifoperation==BlockDataContainer.DIVIDE:op=el.divideelifoperation==BlockDataContainer.POWER:op=el.powerelifoperation==BlockDataContainer.MAXIMUM:op=el.maximumelifoperation==BlockDataContainer.MINIMUM:op=el.minimumelifoperation==BlockDataContainer.SAPYB:ifnotisinstance(other,BlockDataContainer):raiseValueError("{} cannot handle {}".format(operation,type(other)))op=el.sapybelse:raiseValueError('Unsupported operation',operation)ifoutisnotNone:ifoperation==BlockDataContainer.SAPYB:ifisinstance(kw['a'],BlockDataContainer):a=kw['a'].get_item(i)else:a=kw['a']ifisinstance(kw['b'],BlockDataContainer):b=kw['b'].get_item(i)else:b=kw['b']el.sapyb(a,ot,b,out.get_item(i),num_threads=kw['num_threads'])else:kw['out']=out.get_item(i)op(ot,*args,**kw)else:res.append(op(ot,*args,**kw))ifoutisnotNone:returnelse:returntype(self)(*res,shape=self.shape)else:# try to do algebra with one DataContainer. Will raise error if not compatiblekw=kwargs.copy()ifoperation!=BlockDataContainer.SAPYB:# remove keyworded argument related to SAPYBforkin['a','b','y','num_threads','dtype']:ifkinkw.keys():kw.pop(k)res=[]fori,elinenumerate(self.containers):ifoperation==BlockDataContainer.ADD:op=el.addelifoperation==BlockDataContainer.SUBTRACT:op=el.subtractelifoperation==BlockDataContainer.MULTIPLY:op=el.multiplyelifoperation==BlockDataContainer.DIVIDE:op=el.divideelifoperation==BlockDataContainer.POWER:op=el.powerelifoperation==BlockDataContainer.MAXIMUM:op=el.maximumelifoperation==BlockDataContainer.MINIMUM:op=el.minimumelifoperation==BlockDataContainer.SAPYB:ifisinstance(kw['a'],BlockDataContainer):a=kw['a'].get_item(i)else:a=kw['a']ifisinstance(kw['b'],BlockDataContainer):b=kw['b'].get_item(i)else:b=kw['b']el.sapyb(a,other,b,out.get_item(i),kw['num_threads'])# As axpyb cannot return anything we `continue` to skip the rest of the code blockcontinueelse:raiseValueError('Unsupported operation',operation)ifoutisnotNone:kw['out']=out.get_item(i)op(other,*args,**kw)else:res.append(op(other,*args,**kw))ifoutisnotNone:returnelse:returntype(self)(*res,shape=self.shape)
## unary operations
[docs]defunary_operations(self,operation,*args,**kwargs):'''Unary operation on BlockDataContainer: generic method of unary operation with BlockDataContainer: abs, sign, sqrt and conjugate This method is not to be used directly '''out=kwargs.get('out',None)kw=kwargs.copy()ifoutisNone:res=[]forelinself.containers:ifoperation==BlockDataContainer.ABS:op=el.abselifoperation==BlockDataContainer.SIGN:op=el.signelifoperation==BlockDataContainer.SQRT:op=el.sqrtelifoperation==BlockDataContainer.CONJUGATE:op=el.conjugateres.append(op(*args,**kw))returnBlockDataContainer(*res)else:kw.pop('out')forel,eloutinzip(self.containers,out.containers):ifoperation==BlockDataContainer.ABS:op=el.abselifoperation==BlockDataContainer.SIGN:op=el.signelifoperation==BlockDataContainer.SQRT:op=el.sqrtelifoperation==BlockDataContainer.CONJUGATE:op=el.conjugatekw['out']=eloutop(*args,**kw)
defabs(self,*args,**kwargs):returnself.unary_operations(BlockDataContainer.ABS,*args,**kwargs)defsign(self,*args,**kwargs):returnself.unary_operations(BlockDataContainer.SIGN,*args,**kwargs)defsqrt(self,*args,**kwargs):returnself.unary_operations(BlockDataContainer.SQRT,*args,**kwargs)defconjugate(self,*args,**kwargs):returnself.unary_operations(BlockDataContainer.CONJUGATE,*args,**kwargs)# def abs(self, *args, **kwargs):# return type(self)(*[ el.abs(*args, **kwargs) for el in self.containers], shape=self.shape)# def sign(self, *args, **kwargs):# return type(self)(*[ el.sign(*args, **kwargs) for el in self.containers], shape=self.shape)# def sqrt(self, *args, **kwargs):# return type(self)(*[ el.sqrt(*args, **kwargs) for el in self.containers], shape=self.shape)# def conjugate(self, out=None):# return type(self)(*[el.conjugate() for el in self.containers], shape=self.shape)## reductionsdefsum(self,*args,**kwargs):returnnumpy.sum([el.sum(*args,**kwargs)forelinself.containers])defsquared_norm(self):y=numpy.asarray([el.squared_norm()forelinself.containers])returny.sum()defnorm(self):returnnumpy.sqrt(self.squared_norm())defpnorm(self,p=2):ifp==1:returnsum(self.abs())elifp==2:tmp=functools.reduce(lambdaa,b:a+b.conjugate()*b,self.containers,self.get_item(0)*0).sqrt()returntmpelse:returnValueError('Not implemented')
[docs]defcopy(self):'''alias of clone'''returnself.clone()
defclone(self):returntype(self)(*[el.copy()forelinself.containers],shape=self.shape)deffill(self,other):ifisinstance(other,BlockDataContainer):ifnotself.is_compatible(other):raiseValueError('Incompatible containers')forel,otinzip(self.containers,other.containers):el.fill(ot)else:returnValueError('Cannot fill with object provided {}'.format(type(other)))def__add__(self,other):returnself.add(other)# __radd__def__sub__(self,other):returnself.subtract(other)# __rsub__def__mul__(self,other):returnself.multiply(other)# __rmul__def__div__(self,other):returnself.divide(other)# __rdiv__def__truediv__(self,other):returnself.divide(other)def__pow__(self,other):returnself.power(other)# reverse operand
[docs]def__radd__(self,other):'''Reverse addition to make sure that this method is called rather than the __mul__ of a numpy array the class constant __array_priority__ must be set > 0 https://docs.scipy.org/doc/numpy-1.15.1/reference/arrays.classes.html#numpy.class.__array_priority__ '''returnself+other
# __radd__
[docs]def__rsub__(self,other):'''Reverse subtraction to make sure that this method is called rather than the __mul__ of a numpy array the class constant __array_priority__ must be set > 0 https://docs.scipy.org/doc/numpy-1.15.1/reference/arrays.classes.html#numpy.class.__array_priority__ '''return(-1*self)+other
# __rsub__
[docs]def__rmul__(self,other):'''Reverse multiplication to make sure that this method is called rather than the __mul__ of a numpy array the class constant __array_priority__ must be set > 0 https://docs.scipy.org/doc/numpy-1.15.1/reference/arrays.classes.html#numpy.class.__array_priority__ '''returnself*other
# __rmul__
[docs]def__rdiv__(self,other):'''Reverse division to make sure that this method is called rather than the __mul__ of a numpy array the class constant __array_priority__ must be set > 0 https://docs.scipy.org/doc/numpy-1.15.1/reference/arrays.classes.html#numpy.class.__array_priority__ '''returnpow(self/other,-1)
# __rdiv__
[docs]def__rtruediv__(self,other):'''Reverse truedivision to make sure that this method is called rather than the __mul__ of a numpy array the class constant __array_priority__ must be set > 0 https://docs.scipy.org/doc/numpy-1.15.1/reference/arrays.classes.html#numpy.class.__array_priority__ '''returnself.__rdiv__(other)
[docs]def__rpow__(self,other):'''Reverse power to make sure that this method is called rather than the __mul__ of a numpy array the class constant __array_priority__ must be set > 0 https://docs.scipy.org/doc/numpy-1.15.1/reference/arrays.classes.html#numpy.class.__array_priority__ '''returnother.power(self)
[docs]def__iadd__(self,other):'''Inline addition'''ifisinstance(other,BlockDataContainer):forel,otinzip(self.containers,other.containers):el+=otelifisinstance(other,Number):forelinself.containers:el+=otherelifisinstance(other,list)orisinstance(other,numpy.ndarray):ifnotself.is_compatible(other):raiseValueError('Incompatible for __iadd__')forel,otinzip(self.containers,other):el+=otreturnself
# __iadd__
[docs]def__isub__(self,other):'''Inline subtraction'''ifisinstance(other,BlockDataContainer):forel,otinzip(self.containers,other.containers):el-=otelifisinstance(other,Number):forelinself.containers:el-=otherelifisinstance(other,list)orisinstance(other,numpy.ndarray):ifnotself.is_compatible(other):raiseValueError('Incompatible for __isub__')forel,otinzip(self.containers,other):el-=otreturnself
# __isub__
[docs]def__imul__(self,other):'''Inline multiplication'''ifisinstance(other,BlockDataContainer):forel,otinzip(self.containers,other.containers):el*=otelifisinstance(other,Number):forelinself.containers:el*=otherelifisinstance(other,list)orisinstance(other,numpy.ndarray):ifnotself.is_compatible(other):raiseValueError('Incompatible for __imul__')forel,otinzip(self.containers,other):el*=otreturnself
# __imul__
[docs]def__idiv__(self,other):'''Inline division'''ifisinstance(other,BlockDataContainer):forel,otinzip(self.containers,other.containers):el/=otelifisinstance(other,Number):forelinself.containers:el/=otherelifisinstance(other,list)orisinstance(other,numpy.ndarray):ifnotself.is_compatible(other):raiseValueError('Incompatible for __idiv__')forel,otinzip(self.containers,other):el/=otreturnself