# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from cil.framework import Processor, AcquisitionData
from cil.utilities import multiprocessing as cil_mp
import numpy
import logging
import matplotlib.pyplot as plt
import numba
log = logging.getLogger(__name__)
[docs]
class FluxNormaliser(Processor):
r'''
Flux normalisation based on float or region of interest
This processor reads in an AcquisitionData and normalises it by flux from
a float or array of float values, or the mean flux in a region of interest.
Each projection is divided by its flux value and multiplied by the target.
Parameters:
-----------
flux: float or list of floats, optional
Array of floats that describe the variation in brightness of the unobstructed
beam between projections. Must have length equal to the number of projections
in the dataset, or be a single float. If flux=None, calculate
flux from the roi.
roi: dict, optional
Dictionary describing the region of interest containing the background
in the image from which to extract the flux. The roi is specified as
`{'horizontal':(start,stop), 'vertical':(start,stop)}`. If an axis is
not specified in the roi dictionary, the full range will be used.
target: {'mean', 'first', 'last'} or float, default='mean'
The target of the normalised data. If string the data is scaled by the
'mean', 'first' or 'last' flux value. If float, the data is scaled
by the float value.
Default is 'mean'
accelerated: bool, optional
Specify whether to use multi-threading using numba.
Default is True
Returns:
--------
Output: AcquisitionData normalised by flux
Example
-------
This example passes the flux as a list the same size as the data, and
specifies the target='first' which scales all projections to the first flux
value 0.9
>>> from cil.processors import FluxNormaliser
>>> processor = FluxNormaliser(flux=[0.9, 1.0, 1.1, 0.8], target='first')
>>> processor.set_input(data)
>>> data_norm = processor.get_output()
Example
-------
This example calculates the flux from a region of interest for each projection
and scales all projections to the mean flux
>>> from cil.processors import FluxNormaliser
>>> processor = FluxNormaliser(roi={'horizontal':(5, 15)}, target='mean')
>>> processor.set_input(data)
>>> data_norm = processor.get_output()
Note
----
The roi indices provided are start inclusive, stop exclusive.
All elements along a dimension will be included if the axis does not appear
in the roi dictionary
'''
def __init__(self, flux=None, roi=None, target='mean', accelerated=True):
kwargs = {
'flux' : flux,
'roi' : roi,
'roi_slice' : None,
'roi_axes' : None,
'target' : target,
'target_value' : None,
'v_size' : 1,
'v_axis' : None,
'h_size' : 1,
'h_axis' : None,
'_accelerated' : accelerated
}
super(FluxNormaliser, self).__init__(**kwargs)
def check_input(self, dataset):
if self.roi is not None and self.flux is not None:
raise ValueError("Please specify either flux or roi, not both")
if self.roi is None and self.flux is None:
raise ValueError("Please specify either flux or roi, found None")
if not (type(dataset), AcquisitionData):
raise TypeError("Expected AcquistionData, found {}"
.format(type(dataset)))
image_axes = 0
if 'vertical' in dataset.dimension_labels:
self.v_axis = dataset.get_dimension_axis('vertical')
self.v_size = dataset.get_dimension_size('vertical')
image_axes += 1
if 'horizontal' in dataset.dimension_labels:
self.h_axis = dataset.get_dimension_axis('horizontal')
self.h_size = dataset.get_dimension_size('horizontal')
image_axes += 1
if (( self.h_axis is not None) and (self.h_axis < (len(dataset.shape)-image_axes))) or \
((self.v_axis is not None) and self.v_axis < (len(dataset.shape)-image_axes)):
raise ValueError('Projections must be the last two axes of the dataset')
return True
def _calculate_flux(self):
'''
Function to calculate flux from a region of interest in the data. If the
flux is already provided as an array, convert the array to float 32 and
check the size matches the number of projections
'''
dataset = self.get_input()
if dataset is None:
raise ValueError('Data not found, please run `set_input(data)`')
# Calculate the flux from the roi in the data
if self.flux is None:
if isinstance(self.roi, dict):
if not all (r in dataset.dimension_labels for r in self.roi):
raise ValueError("roi labels must be in the dataset dimension_labels, found {}"
.format(str(self.roi)))
slc = [slice(None)]*len(dataset.shape)
axes=[]
for r in self.roi:
# only allow roi to be specified in horizontal and vertical
if (r != 'horizontal' and r != 'vertical'):
raise ValueError("roi must be 'horizontal' or 'vertical', found '{}'"
.format(str(r)))
for d in ['horizontal', 'vertical']:
if d in self.roi:
# check indices are ints
if not all(isinstance(i, int) for i in self.roi[d]):
raise TypeError("roi values must be int, found {} and {}"
.format(str(type(self.roi[d][0])), str(type(self.roi[d][1]))))
# check indices are in range
elif (self.roi[d][0] >= self.roi[d][1]) or (self.roi[d][0] < 0) or self.roi[d][1] > dataset.get_dimension_size(d):
raise ValueError("roi values must be start > stop and between 0 and {}, found start={} and stop={} for direction '{}'"
.format(str(dataset.get_dimension_size(d)), str(self.roi[d][0]), str(self.roi[d][1]), d ))
# create slice
else:
ax = dataset.get_dimension_axis(d)
slc[ax] = slice(self.roi[d][0], self.roi[d][1])
axes.append(ax)
# if a projection dimension isn't in the roi, use the whole axis
else:
if d in dataset.dimension_labels:
ax = dataset.get_dimension_axis(d)
axes.append(ax)
self.roi.update({d:(0,dataset.get_dimension_size(d))})
self.flux = numpy.mean(dataset.array[tuple(slc)], axis=tuple(axes))
# Warn if the flux is more than 10% of the dataset range
dataset_range = numpy.max(dataset.array, axis=tuple(axes)) - numpy.min(dataset.array, axis=tuple(axes))
if (numpy.mean(self.flux) > dataset.mean()):
if numpy.mean(self.flux/dataset_range) < 0.9:
log.warning('Warning: mean value in selected roi is more than 10 percent of data range - may not represent the background')
else:
if numpy.mean(self.flux/dataset_range) > 0.1:
log.warning('Warning: mean value in selected roi is more than 10 percent of data range - may not represent the background')
self.roi_slice = slc
self.roi_axes = axes
else:
raise TypeError("roi must be a dictionary, found {}"
.format(str(type(self.roi))))
# convert flux array to float32
self.flux = numpy.array(self.flux, dtype=numpy.float32, ndmin=1)
# check flux array is the right size
flux_size_flat = len(self.flux.ravel())
if flux_size_flat > 1:
data_size_flat = len(dataset.geometry.angles)*dataset.geometry.channels
if data_size_flat != flux_size_flat:
raise ValueError("Flux must be a scalar or array with length \
\n = number of projections, found {} and {}"
.format(flux_size_flat, data_size_flat))
# check if flux array contains 0s
if 0 in self.flux:
raise ValueError('Flux value can\'t be 0, provide a different flux\
or region of interest with non-zero values')
def _calculate_target(self):
'''
Calculate the target value for the normalisation
'''
if self.flux is None:
raise ValueError('Flux not found')
if isinstance(self.target, (int,float)):
self.target_value = self.target
elif isinstance(self.target, str):
if self.target == 'first':
if len(numpy.shape(self.flux)) > 0 :
self.target_value = self.flux.flat[0]
else:
self.target_value = self.flux
elif self.target == 'last':
if len(numpy.shape(self.flux)) > 0 :
self.target_value = self.flux.flat[-1]
else:
self.target_value = self.flux
elif self.target == 'mean':
self.target_value = numpy.mean(self.flux.ravel())
else:
raise ValueError("Target string not recognised, found {}, expected 'first' or 'mean'"
.format(self.target))
else:
raise TypeError("Target must be string or a number, found {}"
.format(type(self.target)))
[docs]
def preview_configuration(self, angle=None, channel=None, log=False):
'''
Preview the FluxNormalisation processor configuration for roi mode.
Plots the region of interest on the image and the mean, maximum and
minimum intensity in the roi.
Parameters:
-----------
angle: float, optional
Index of the angle to plot, default=None displays the data with the
minimum and maximum pixel values in the roi. For 2D data, the roi is
plotted on the sinogram.
channel: int, optional
The channel to plot, default=None displays the central channel if
the data has channels
log: bool, default=False
If True, plot the image with a log scale, default is False
Returns:
--------
matplotlib.figure.Figure
The figure object created to plot the configuration
'''
self._calculate_flux()
if self.roi_slice is None:
raise ValueError('Preview available with roi, run `processor= FluxNormaliser(roi=roi)` then `set_input(data)`')
else:
data = self.get_input()
min = numpy.min(data.array[tuple(self.roi_slice)], axis=tuple(self.roi_axes))
max = numpy.max(data.array[tuple(self.roi_slice)], axis=tuple(self.roi_axes))
if 'channel' in data.dimension_labels:
if channel is None:
channel = int(data.get_dimension_size('channel')/2)
channel_axis = data.get_dimension_axis('channel')
flux_array = self.flux.take(indices=channel, axis=channel_axis)
min = min.take(indices=channel, axis=channel_axis)
max = max.take(indices=channel, axis=channel_axis)
else:
if channel is not None:
raise ValueError("Channel not found")
else:
flux_array = self.flux
plt.figure(figsize=(8,8))
if data.geometry.dimension == '3D':
if angle is None:
if 'angle' in data.dimension_labels:
self._plot_slice_roi(angle_index=numpy.argmin(min), channel_index=channel, log=log, ax=221)
self._plot_slice_roi(angle_index=numpy.argmax(max), channel_index=channel, log=log, ax=222)
else:
self._plot_slice_roi(log=log, channel_index=channel, ax=211)
else:
if 'angle' in data.dimension_labels:
self._plot_slice_roi(angle_index=angle, channel_index=channel, log=log, ax=211)
else:
self._plot_slice_roi(log=log, channel_index=channel, ax=211)
# if data is 2D plot roi on all angles
elif data.geometry.dimension == '2D':
if angle is None:
self._plot_slice_roi(channel_index=channel, log=log, ax=211)
else:
raise ValueError("Cannot plot ROI for a single angle on 2D data, please specify angle=None to plot ROI on the sinogram")
plt.subplot(212)
if len(data.geometry.angles)==1:
plt.plot(0, flux_array, '.r', label='Mean')
plt.plot(0, min,'.k', label='Minimum')
plt.plot(0, max,'.k', label='Maximum')
else:
indices = range(data.get_dimension_size('angle'))
plt.plot(indices, flux_array, 'r', label='Mean')
plt.plot(indices, min,'--k', label='Minimum')
plt.plot(indices, max,'--k', label='Maximum')
plt.legend()
plt.xlabel('angle index')
plt.ylabel('Intensity in roi')
plt.grid()
ax1 = plt.gca()
ax2 = ax1.twiny()
valid_ticks = [int(tick) for tick in ax1.get_xticks() if 0 <= tick < len(data.geometry.angles)]
ax2.set_xticks(valid_ticks)
ax2.set_xbound(ax1.get_xbound())
ax2.set_xticklabels([data.geometry.angles[tick] for tick in valid_ticks])
ax2.set_xlabel('angle')
plt.tight_layout()
fig = plt.gcf()
plt.show()
return fig
def _plot_slice_roi(self, angle_index=None, channel_index=None, log=False, ax=111):
'''
Plot the region of interest on a data slice
Parameters:
-----------
angle_index: int, optional
Index of the angle to plot
channel_index: int, optional
Index of the channel to plot
log: bool, optional
Plot the log of the slice intensity to highlight small variations
ax: int, default=111
The subplot axis to display the slice on
'''
data = self.get_input()
if angle_index is not None and 'angle' in data.dimension_labels:
data_slice = data.get_slice(angle=angle_index)
else:
data_slice = data
if 'channel' in data.dimension_labels:
data_slice = data_slice.get_slice(channel=channel_index)
if len(data_slice.shape) != 2:
raise ValueError("Data shape not compatible with preview_configuration(), data must have at least two of 'horizontal', 'vertical' and 'angle'")
# if horizontal and vertical are not specified in the roi, get the
# min and max extent from the full size of the dimension
extent = [0, data_slice.shape[1], 0, data_slice.shape[0]]
if 'angle' in data_slice.dimension_labels:
min_angle = data_slice.geometry.angles[0]
max_angle = data_slice.geometry.angles[-1]
for i, d in enumerate(data_slice.dimension_labels):
if d !='angle':
extent[i*2]=min_angle
extent[i*2+1]=max_angle
# plot the specified data slice
ax1 = plt.subplot(ax)
if log:
im = ax1.imshow(numpy.log(data_slice.array), cmap='gray',aspect='equal', origin='lower', extent=extent)
plt.gcf().colorbar(im, ax=ax1)
else:
im = ax1.imshow(data_slice.array, cmap='gray',aspect='equal', origin='lower', extent=extent)
plt.gcf().colorbar(im, ax=ax1)
h = data_slice.dimension_labels[1]
v = data_slice.dimension_labels[0]
# get the box to plot from the roi
if h == 'angle':
h_min = min_angle
h_max = max_angle
else:
h_min = self.roi[h][0]
h_max = self.roi[h][1]
if v == 'angle':
v_min = min_angle
v_max = max_angle
else:
v_min = self.roi[v][0]
v_max = self.roi[v][1]
# plot the roi box
ax1.plot([h_min, h_max],[v_min, v_min],'--r')
ax1.plot([h_min, h_max],[v_max, v_max],'--r')
ax1.plot([h_min, h_min],[v_min, v_max],'--r')
ax1.plot([h_max, h_max],[v_min, v_max],'--r')
title = 'ROI'
if angle_index is not None:
title += ' angle = ' + str(data.geometry.angles[angle_index])
if channel_index is not None:
title += ' channel = ' + str(channel_index)
ax1.set_title(title)
ax1.set_xlabel(h)
ax1.set_ylabel(v)
def process(self, out=None):
self._calculate_flux()
self._calculate_target()
data = self.get_input()
if out is None:
out = data.copy()
elif id(out) != id(data):
numpy.copyto(out.array, data.array)
proj_size = self.v_size*self.h_size
num_proj = int(data.array.size / proj_size)
if self._accelerated:
num_threads_original = numba.get_num_threads()
numba.set_num_threads(cil_mp.NUM_THREADS)
numba_loop(self.flux, self.target_value, num_proj, proj_size, out.array)
# reset the number of threads to the original value
numba.set_num_threads(num_threads_original)
else:
serial_loop(self.flux, self.target_value, num_proj, proj_size, out.array)
return out
@numba.njit(parallel=True)
def numba_loop(flux, target, num_proj, proj_size, out):
out_flat = out.ravel()
flux_flat = flux.ravel()
if len(flux) == 1:
norm = target/flux_flat[0]
for i in numba.prange(num_proj):
for ij in range(proj_size):
out_flat[i*proj_size+ij] *= norm
else:
for i in numba.prange(num_proj):
for ij in range(proj_size):
out_flat[i*proj_size+ij] *= (target/flux_flat[i])
def serial_loop(flux, target, num_proj, proj_size, out):
out_reshaped = out.reshape(num_proj, proj_size)
flux_flat = flux.ravel()
norm = target / flux_flat[:, numpy.newaxis] # shape: (num_proj, 1)
numpy.multiply(out_reshaped, norm, out=out_reshaped)