# Copyright 2020 United Kingdom Research and Innovation
# Copyright 2020 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from scipy.fftpack import fftshift, ifftshift, fft, ifft
import numpy as np
import pywt
from cil.framework import Processor, ImageData, AcquisitionData
[docs]
class RingRemover(Processor):
    '''
    RingRemover Processor: Removes vertical stripes from a DataContainer (ImageData/AcquisitionData)
    using the combined wavelet-Fourier filtering algorithm in https://doi.org/10.1364/OE.17.008567

    Parameters
    ----------
    decNum : int
        Number of wavelet decompositions - increasing the number of decompositions increases the
        strength of the ring removal but can alter the profile of the data
    wname : str
        Name of wavelet filter from pywt e.g. 'db1' -- 'db35', 'haar' - increasing the wavelet filter
        increases the strength of the ring removal, but also increases the computational effort
    sigma : float
        Damping parameter in Fourier space - increasing sigma increases the size of artefacts which
        can be removed
    info : boolean
        Flag to enable printing of progress messages and of the ring remover end message

    Returns
    -------
    DataContainer
        Corrected ImageData/AcquisitionData 2D, 3D, multi-spectral 2D, multi-spectral 3D
    '''

    def __init__(self, decNum=4, wname='db10', sigma=1.5, info=True):
        '''Store the filter settings as processor attributes via the base class constructor.'''
        kwargs = {'decNum': decNum,
                  'wname': wname,
                  'sigma': sigma,
                  'info': info}

        super(RingRemover, self).__init__(**kwargs)

    def check_input(self, dataset):
        '''Validate the dataset handed to the processor.

        Parameters
        ----------
        dataset : ImageData or AcquisitionData
            Candidate input

        Returns
        -------
        bool
            True when the dataset is of a supported type and carries a geometry

        Raises
        ------
        TypeError
            If the dataset is neither ImageData nor AcquisitionData
        ValueError
            If the dataset has no geometry
        '''
        if not isinstance(dataset, (ImageData, AcquisitionData)):
            raise TypeError('Processor supports only following data types:\n' +
                            ' - ImageData\n - AcquisitionData')

        if dataset.geometry is None:
            raise ValueError('Geometry is not defined.')

        return True

    def process(self, out=None):
        '''Apply the stripe removal to every 2D slice of the input.

        The input is split along the 'channel' and 'vertical' axes (whichever are present in the
        geometry's dimension labels) and each remaining 2D slice is filtered independently.

        Parameters
        ----------
        out : DataContainer, optional
            Container to fill with the corrected data. When None, a new container is created
            from the input.

        Returns
        -------
        DataContainer
            The container holding the corrected data (``out`` itself when it was supplied)
        '''
        data = self.get_input()
        decNum = self.decNum
        wname = self.wname
        sigma = self.sigma
        info = self.info

        geom = data.geometry
        labels = list(geom.dimension_labels)
        has_channel = 'channel' in labels
        has_vertical = 'vertical' in labels

        # The number of vertical slices is taken from the array itself, and only
        # when a vertical axis exists, rather than from a geometry attribute that
        # is specific to one geometry type and is not needed for 2D data.
        # NOTE(review): assumes the array axes follow geom.dimension_labels order - confirm.
        if has_vertical:
            vertical = data.as_array().shape[labels.index('vertical')]

        # allocate datacontainer space
        if out is None:
            out = 0. * data

        if not has_channel:
            if has_vertical:
                # single channel, 3D data: filter each vertical slice
                for i in range(vertical):
                    tmp_corrected = self._xRemoveStripesVertical(
                        data.get_slice(vertical=i, force=True).as_array(), decNum, wname, sigma)
                    out.fill(tmp_corrected, vertical=i)
            else:
                # single channel, 2D data: filter the whole array
                tmp_corrected = self._xRemoveStripesVertical(data.as_array(), decNum, wname, sigma)
                out.fill(tmp_corrected)
        else:
            channels = geom.channels

            if has_vertical:
                # multichannel, 3D data: filter each vertical slice of each channel
                for i in range(channels):
                    out_ch_i = out.get_slice(channel=i)
                    data_ch_i = data.get_slice(channel=i)

                    for j in range(vertical):
                        tmp_corrected = self._xRemoveStripesVertical(
                            data_ch_i.get_slice(vertical=j, force=True).as_array(), decNum, wname, sigma)
                        out_ch_i.fill(tmp_corrected, vertical=j)

                    # copy the corrected channel back into the output container
                    out.fill(out_ch_i.as_array(), channel=i)

                    if info:
                        print("Finish channel {}".format(i))
            else:
                # multichannel, 2D data: filter each channel
                for i in range(channels):
                    tmp_corrected = self._xRemoveStripesVertical(
                        data.get_slice(channel=i).as_array(), decNum, wname, sigma)
                    out.fill(tmp_corrected, channel=i)

                    if info:
                        print("Finish channel {}".format(i))

        if info:
            print("Finish Ring Remover")

        return out

    def _xRemoveStripesVertical(self, ima, decNum, wname, sigma):
        ''' Ring removal algorithm via combined wavelet and fourier filtering
        code from https://doi.org/10.1364/OE.17.008567
        translated in Python

        Parameters
        ----------
        ima : ndarray
            2D image data
        decNum : int
            Number of wavelet decompositions - increasing the number of decompositions increases the
            strength of the ring removal but can alter the profile of the data
        wname : str
            Name of wavelet filter from pywt e.g. 'db1' -- 'db35', 'haar' - increasing the wavelet filter
            increases the strength of the ring removal, but also increases the computational effort
        sigma : float
            Damping parameter in Fourier space - increasing sigma increases the size of artefacts
            which can be removed

        Returns
        -------
        Corrected 2D sinogram data (Numpy Array)
        '''
        # remember the input extent so the reconstruction can be cropped back to it
        n_rows, n_cols = ima.shape[0], ima.shape[1]

        # detail coefficients (cH, cV, cD) of every decomposition level
        Ch = [None] * decNum
        Cv = [None] * decNum
        Cd = [None] * decNum

        # wavelet decomposition: after the loop `ima` holds the coarsest approximation
        for i in range(decNum):
            ima, (Ch[i], Cv[i], Cd[i]) = pywt.dwt2(ima, wname)

        # damp the Cv detail band of every level in Fourier space
        for i in range(decNum):
            # transform along axis=0, which corresponds to the angles direction
            fCv = fftshift(fft(Cv[i], axis=0))
            my = fCv.shape[0]

            # centred integer frequency indices, one per row of fCv
            half = my // 2
            freq = np.arange(-half, my - half)

            # Gaussian notch around zero frequency, as a column so it broadcasts over columns
            damp = 1 - np.exp(-freq**2 / (2 * sigma**2))
            fCv *= damp[:, np.newaxis]

            # inverse FFT
            Cv[i] = np.real(ifft(ifftshift(fCv), axis=0))

        # wavelet reconstruction, from the coarsest level back to the finest
        nima = ima
        for i in range(decNum - 1, -1, -1):
            # the approximation can be one sample larger than the detail bands; trim it to match
            nima = nima[0:Ch[i].shape[0], 0:Ch[i].shape[1]]
            nima = pywt.idwt2((nima, (Ch[i], Cv[i], Cd[i])), wname)

        # if the original input is odd, the signal reconstructed with idwt2 will have one
        # extra sample, which can be discarded
        return nima[:n_rows, :n_cols]