# Source code for cil.processors.RingRemover

#  Copyright 2020 United Kingdom Research and Innovation
#  Copyright 2020 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

from scipy.fftpack import fftshift, ifftshift, fft, ifft
import numpy as np
import pywt
from cil.framework import Processor, ImageData, AcquisitionData

class RingRemover(Processor):
    '''
    RingRemover Processor: Removes vertical stripes from a DataContainer
    (ImageData/AcquisitionData) using the algorithm in
    https://doi.org/10.1364/OE.17.008567

    Parameters
    ----------
    decNum : int
        Number of wavelet decompositions - increasing the number of
        decompositions, increases the strength of the ring removal
        but can alter the profile of the data

    wname : str
        Name of wavelet filter from pywt e.g. 'db1' -- 'db35', 'haar' -
        increasing the wavelet filter increases the strength of the ring
        removal, but also increases the computational effort

    sigma : float
        Damping parameter in Fourier space - increasing sigma, increases
        the size of artefacts which can be removed

    info : boolean
        Flag to enable print of ring remover end message (and of the
        per-channel progress messages for multichannel data)

    Returns
    -------
    DataContainer
        Corrected ImageData/AcquisitionData 2D, 3D,
        multi-spectral 2D, multi-spectral 3D
    '''

    def __init__(self, decNum=4, wname='db10', sigma=1.5, info=True):
        '''Store the filter settings on the processor via the base class.'''
        kwargs = {'decNum': decNum,
                  'wname': wname,
                  'sigma': sigma,
                  'info': info}

        super(RingRemover, self).__init__(**kwargs)

    def check_input(self, dataset):
        '''
        Validate the dataset handed to the processor.

        Parameters
        ----------
        dataset : ImageData or AcquisitionData
            Candidate input; it must carry a geometry.

        Returns
        -------
        bool
            True when the dataset is acceptable.

        Raises
        ------
        TypeError
            If dataset is neither ImageData nor AcquisitionData.
        ValueError
            If dataset.geometry is None.
        '''
        if not isinstance(dataset, (ImageData, AcquisitionData)):
            raise TypeError('Processor supports only following data types:\n' +
                            ' - ImageData\n - AcquisitionData')

        # identity test: never dispatch to a user-defined __eq__ for the None check
        if dataset.geometry is None:
            raise ValueError('Geometry is not defined.')

        return True

    def process(self, out=None):
        '''
        Apply the stripe removal to every 2D slice of the input.

        The data are split along 'channel' and 'vertical' (when those labels
        are present in the geometry), each remaining 2D array is filtered with
        _xRemoveStripesVertical and written back into the output container.

        Parameters
        ----------
        out : DataContainer, optional
            Container to fill with the corrected data. When None, a new
            container is created as 0.*data.

        Returns
        -------
        DataContainer
            The filled container (the same object as out when it was given).
        '''
        data = self.get_input()
        decNum = self.decNum
        wname = self.wname
        sigma = self.sigma
        info = self.info

        geom = data.geometry
        labels = list(geom.dimension_labels)
        has_channel = 'channel' in labels
        has_vertical = 'vertical' in labels

        channels = geom.channels

        # NOTE(review): pixel_num_v is presumably only defined on acquisition
        # geometries - confirm. When it is missing, fall back to the length of
        # the 'vertical' axis, assuming the array axes follow
        # geom.dimension_labels.
        vertical = getattr(geom, 'pixel_num_v', None)
        if vertical is None and has_vertical:
            vertical = data.as_array().shape[labels.index('vertical')]

        # allocate datacontainer space
        if out is None:
            out = 0.*data

        if not has_channel:
            if has_vertical:
                # single channel, 3D: filter one vertical slice at a time
                for i in range(vertical):
                    sino = data.get_slice(vertical=i, force=True).as_array()
                    tmp_corrected = self._xRemoveStripesVertical(
                        sino, decNum, wname, sigma)
                    out.fill(tmp_corrected, vertical=i)
            else:
                # single channel, 2D: the whole array is one slice
                tmp_corrected = self._xRemoveStripesVertical(
                    data.as_array(), decNum, wname, sigma)
                out.fill(tmp_corrected)
        else:
            if has_vertical:
                # multichannel, 3D: correct each channel slice by slice, then
                # copy the finished channel back into the output
                for i in range(channels):
                    out_ch_i = out.get_slice(channel=i)
                    data_ch_i = data.get_slice(channel=i)
                    for j in range(vertical):
                        sino = data_ch_i.get_slice(vertical=j, force=True).as_array()
                        tmp_corrected = self._xRemoveStripesVertical(
                            sino, decNum, wname, sigma)
                        out_ch_i.fill(tmp_corrected, vertical=j)
                    out.fill(out_ch_i.as_array(), channel=i)
                    if info:
                        print("Finish channel {}".format(i))
            else:
                # multichannel, 2D: one slice per channel
                for i in range(channels):
                    tmp_corrected = self._xRemoveStripesVertical(
                        data.get_slice(channel=i).as_array(), decNum, wname, sigma)
                    out.fill(tmp_corrected, channel=i)
                    if info:
                        print("Finish channel {}".format(i))

        if info:
            print("Finish Ring Remover")

        return out

    def _xRemoveStripesVertical(self, ima, decNum, wname, sigma):
        '''
        Ring removal algorithm via combined wavelet and fourier filtering
        code from https://doi.org/10.1364/OE.17.008567
        translated in Python

        Parameters
        ----------
        ima : ndarray
            2D image data

        decNum : int
            Number of wavelet decompositions - increasing the number of
            decompositions, increases the strength of the ring removal
            but can alter the profile of the data

        wname : str
            Name of wavelet filter from pywt e.g. 'db1' -- 'db35', 'haar' -
            increasing the wavelet filter increases the strength of the ring
            removal, but also increases the computational effort

        sigma : float
            Damping parameter in Fourier space - increasing sigma, increases
            the size of artefacts which can be removed

        Returns
        -------
        Corrected 2D sinogram data (Numpy Array)
        '''
        # remember the input extent so that padding added by the wavelet
        # round trip can be cropped away at the end
        original_extent = (slice(None, ima.shape[0], None),
                           slice(None, ima.shape[1], None))

        # detail coefficients of every decomposition level
        Ch = [None]*decNum
        Cv = [None]*decNum
        Cd = [None]*decNum

        # wavelets decomposition: ima ends up as the coarsest approximation
        for i in range(decNum):
            ima, (Ch[i], Cv[i], Cd[i]) = pywt.dwt2(ima, wname)

        # damp the Cv detail band of every level in Fourier space
        for i in range(decNum):
            # the FFT is taken along axis=0, which corresponds to the angles
            # direction; fftshift/ifftshift act on both axes and cancel out
            fCv = fftshift(fft(Cv[i], axis=0))
            my = fCv.shape[0]

            # shifted integer frequency index, centred on zero
            half = my // 2
            freq = np.arange(-half, -half + my)

            # inverted Gaussian: suppresses the frequencies around zero
            damp = 1 - np.exp(-freq**2/(2*sigma**2))
            fCv *= damp[:, np.newaxis]

            # inverse FFT
            Cv[i] = np.real(ifft(ifftshift(fCv), axis=0))

        # wavelet reconstruction, from the coarsest level back to the finest
        nima = ima
        for i in range(decNum-1, -1, -1):
            # crop the approximation to the size of this level's details
            nima = nima[0:Ch[i].shape[0], 0:Ch[i].shape[1]]
            nima = pywt.idwt2((nima, (Ch[i], Cv[i], Cd[i])), wname)

        # if the original input is odd, the signal reconstructed with idwt2
        # will have one extra sample, which can be discarded
        nima = nima[original_extent[0], original_extent[1]]

        return nima