[15]:
# -*- coding: utf-8 -*-
#  Copyright 2024 -  United Kingdom Research and Innovation
#  Copyright 2024 -  The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
#   Authored by:    CIL contributors

CIL Callback demonstration#

This notebook runs on CIL Master (built on 14/03/2024) and demonstrates the new callback functionality.

[16]:
from cil.utilities import dataexample
from cil.utilities.display import show2D
from cil.recon import FDK
from cil.processors import TransmissionAbsorptionConverter, Slicer
from cil.utilities.quality_measures import psnr
import numpy as np
import matplotlib.pyplot as plt
from cil.plugins.tigre import ProjectionOperator
from cil.optimisation.algorithms import FISTA, Algorithm
from cil.optimisation.functions import LeastSquares, IndicatorBox, ZeroFunction, TotalVariation
from cil.optimisation.operators import GradientOperator
from cil.optimisation.utilities import callbacks
from cil.framework import DataContainer

from cil.utilities.quality_measures import mse, mae, psnr

# set up default colour map for visualisation
# NOTE(review): the variable `cmap` is not referenced again in the cells visible here
# (later calls pass literal `cmap=[...]` lists) — confirm it is still needed.
cmap = "gray"

# set the backend for FBP and the ProjectionOperator
# NOTE(review): `device` is not passed to FDK or ProjectionOperator in the cells
# visible here — confirm how the backend is actually selected.
device = 'gpu'


Load Data#

[17]:

#%% Load data
# Reference volume and the simulated cone-beam acquisition shipped with CIL
ground_truth = dataexample.SIMULATED_SPHERE_VOLUME.get()
data = dataexample.SIMULATED_CONE_BEAM_DATA.get()

# Work on the central slice only so that the demonstration runs quickly
twoD = True
if twoD:
    data = data.get_slice(vertical='centre')
    ground_truth = ground_truth.get_slice(vertical='centre')

# Convert transmission to absorption, then keep every 5th projection angle
absorption = TransmissionAbsorptionConverter()(data)
absorption = Slicer(roi={'angle':(0, -1, 5)})(absorption)

# Reconstruct on the same grid as the ground truth
ig = ground_truth.geometry

#%%
recon = FDK(absorption, image_geometry=ig).run()

#%%
show2D([ground_truth, recon], title = ['Ground Truth', 'FDK Reconstruction'], origin = 'upper', num_cols = 2)

# %%
FDK recon

Input Data:
        angle: 60
        horizontal: 128

Reconstruction Volume:
        horizontal_y: 128
        horizontal_x: 128

Reconstruction Options:
        Backend: tigre
        Filter: ram-lak
        Filter cut-off frequency: 1.0
        FFT order: 8
        Filter_inplace: False

../../_images/demos_callback_demonstration_4_1.png
[17]:
<cil.utilities.display.show2D at 0x7f0908e75840>

Default behaviour#

[18]:


# Regularisation weight applied to the total-variation term
alpha=0.1

# Forward model and the two parts of the objective: least-squares data fit + weighted TV
A = ProjectionOperator(image_geometry=ig, acquisition_geometry=absorption.geometry)
F = LeastSquares(A = A, b = absorption)
G = alpha*TotalVariation(lower=0)

# Run FISTA without passing any callbacks, i.e. with the default reporting
algo=FISTA(initial=ig.allocate(0), f=F, g=G)
algo.run(500)
show2D([ground_truth, recon, algo.solution], title = ['Ground Truth', 'FDK Reconstruction', 'TV solution'], origin = 'upper', num_cols = 3)
../../_images/demos_callback_demonstration_6_1.png
[18]:
<cil.utilities.display.show2D at 0x7f0908e74b80>

Other provided CIL callbacks#

[19]:
# Fresh FISTA instance from a zero image; `update_objective_interval=10` is
# reflected in the progress output, which reports the objective every 10 iterations.
algo = FISTA(
    initial=ig.allocate(0),
    f=F,
    g=G,
    update_objective_interval=10,
)

# Attach both built-in progress reporters for this run
algo.run(
    500,
    callbacks=[
        callbacks.ProgressCallback(),
        callbacks.TextProgressCallback(),
    ],
)

show2D(
    [ground_truth, recon, algo.solution],
    title=['Ground Truth', 'FDK Reconstruction', 'TV solution'],
    origin='upper',
    num_cols=3,
)

     0/500        ?it/s
    10/500    29.53it/s, objective=+8.586e+01
    20/500    29.53it/s, objective=+9.047e+00
    30/500    28.65it/s, objective=+2.640e+00
    40/500    28.53it/s, objective=+1.188e+00
    50/500    18.75it/s, objective=+6.928e-01
    60/500    21.45it/s, objective=+4.585e-01
    70/500    23.55it/s, objective=+3.451e-01
    80/500    25.28it/s, objective=+2.741e-01
    90/500    26.58it/s, objective=+2.337e-01
   100/500    27.57it/s, objective=+2.062e-01
   110/500    28.17it/s, objective=+1.870e-01
   120/500    20.05it/s, objective=+1.729e-01
   130/500    22.35it/s, objective=+1.623e-01
   140/500    24.32it/s, objective=+1.541e-01
   150/500    25.89it/s, objective=+1.476e-01
   160/500    26.91it/s, objective=+1.421e-01
   170/500    27.82it/s, objective=+1.375e-01
   180/500    28.58it/s, objective=+1.336e-01
   190/500    20.24it/s, objective=+1.303e-01
   200/500    22.51it/s, objective=+1.275e-01
   210/500    24.47it/s, objective=+1.250e-01
   220/500    26.06it/s, objective=+1.228e-01
   230/500    27.32it/s, objective=+1.208e-01
   240/500    28.08it/s, objective=+1.190e-01
   250/500    28.55it/s, objective=+1.173e-01
   260/500    28.69it/s, objective=+1.159e-01
   270/500    28.86it/s, objective=+1.145e-01
   280/500    29.03it/s, objective=+1.134e-01
   290/500    29.44it/s, objective=+1.123e-01
   300/500    29.27it/s, objective=+1.113e-01
   310/500    29.51it/s, objective=+1.104e-01
   320/500    21.28it/s, objective=+1.096e-01
   330/500    22.71it/s, objective=+1.088e-01
   340/500    24.61it/s, objective=+1.081e-01
   350/500    26.25it/s, objective=+1.075e-01
   360/500    27.47it/s, objective=+1.068e-01
   370/500    28.42it/s, objective=+1.063e-01
   380/500    29.10it/s, objective=+1.057e-01
   390/500    20.90it/s, objective=+1.052e-01
   400/500    23.14it/s, objective=+1.048e-01
   410/500    24.79it/s, objective=+1.043e-01
   420/500    25.48it/s, objective=+1.039e-01
   430/500    26.67it/s, objective=+1.035e-01
   440/500    27.03it/s, objective=+1.031e-01
   450/500    27.67it/s, objective=+1.028e-01
   460/500    27.81it/s, objective=+1.025e-01
   470/500    28.66it/s, objective=+1.022e-01
   480/500    29.28it/s, objective=+1.019e-01
   490/500    29.76it/s, objective=+1.016e-01
   500/500    30.16it/s, objective=+1.013e-01
   500/500    26.07it/s

../../_images/demos_callback_demonstration_8_2.png
[19]:
<cil.utilities.display.show2D at 0x7f0908e75030>
[20]:
# Continue the same algorithm object for another 100 iterations
algo.run(
    100,
    callbacks=[
        callbacks.ProgressCallback(),
        callbacks.TextProgressCallback(),
    ],
)
   501/600        ?it/s
   511/600    30.76it/s, objective=+1.011e-01
   521/600    30.74it/s, objective=+1.008e-01
   531/600    30.69it/s, objective=+1.006e-01
   541/600    30.26it/s, objective=+1.004e-01
   551/600    28.95it/s, objective=+1.002e-01
   561/600    20.01it/s, objective=+1.000e-01
   571/600    22.57it/s, objective=+9.986e-02
   581/600    24.69it/s, objective=+9.969e-02
   591/600    26.41it/s, objective=+9.953e-02
   600/600    26.40it/s, objective=+9.938e-02

Early stopping (a custom callback example)#

[21]:
class EarlyStopping(callbacks.Callback):
    """Callback that ends the run once the latest objective value is small enough.

    Raises
    ------
    StopIteration
        When the most recently stored objective value is at most 2e-1.
    """

    def __call__(self, algorithm):
        # arbitrary stopping criterion
        latest_objective = algorithm.objective[-1]
        if latest_objective <= 2e-1:
            raise StopIteration

# Restart from a zero image and let the custom callback decide when to stop
algo = FISTA(
    initial=ig.allocate(0),
    f=F,
    g=G,
    update_objective_interval=10,
)
algo.run(
    500,
    callbacks=[
        callbacks.TextProgressCallback(),
        EarlyStopping(),
    ],
)
show2D(
    [ground_truth, recon, algo.solution],
    title=['Ground Truth', 'FDK Reconstruction', 'TV solution'],
    origin='upper',
    num_cols=3,
)

     0/500        ?it/s
    10/500    31.71it/s, objective=+8.586e+01
    20/500    31.18it/s, objective=+9.047e+00
    30/500    29.77it/s, objective=+2.640e+00
    40/500    30.06it/s, objective=+1.188e+00
    50/500    30.28it/s, objective=+6.928e-01
    60/500    30.35it/s, objective=+4.585e-01
    70/500    30.46it/s, objective=+3.451e-01
    80/500    30.57it/s, objective=+2.741e-01
    90/500    22.27it/s, objective=+2.337e-01
   100/500    22.69it/s, objective=+2.062e-01
   110/500    24.61it/s, objective=+1.870e-01
   110/500    26.86it/s

../../_images/demos_callback_demonstration_11_1.png
[21]:
<cil.utilities.display.show2D at 0x7f08d5daaec0>
[22]:
class EarlyStopping(callbacks.Callback):
    """Callback that ends the run once the iterate is close to the ground truth.

    The check uses the mean squared difference between ``algorithm.x`` and the
    notebook-level ``ground_truth`` image.

    Raises
    ------
    StopIteration
        When that mean squared difference is at most 3e-8.
    """

    def __call__(self, algorithm):
        # arbitrary stopping criterion
        squared_difference = (algorithm.x.array - ground_truth.array) ** 2
        if np.mean(squared_difference) <= 3e-8:
            raise StopIteration

# Restart from a zero image and stop as soon as the custom criterion is met
algo = FISTA(
    initial=ig.allocate(0),
    f=F,
    g=G,
    update_objective_interval=10,
)
algo.run(
    500,
    callbacks=[
        callbacks.TextProgressCallback(),
        EarlyStopping(),
    ],
)
show2D(
    [ground_truth, recon, algo.solution],
    title=['Ground Truth', 'FDK Reconstruction', 'TV solution'],
    origin='upper',
    num_cols=3,
)
     0/500        ?it/s
    10/500    23.79it/s, objective=+8.586e+01
    20/500    26.96it/s, objective=+9.047e+00
    23/500    26.89it/s

../../_images/demos_callback_demonstration_12_1.png
[22]:
<cil.utilities.display.show2D at 0x7f08d5b68d90>

Calculating data discrepancy at each iteration (A custom callback example)#

[23]:
class DataDiscrepancyCallback(callbacks.Callback):
    """Record a least-squares data discrepancy of the current output at every call.

    Parameters
    ----------
    A :
        Operator handed to ``LeastSquares`` as ``A``.
    data :
        Measured data handed to ``LeastSquares`` as ``b``.

    Attributes
    ----------
    save_values : list
        One ``LeastSquares`` value per call, in call order.
    """

    def __init__(self, A, data):
        self.save_values = []
        self.f = LeastSquares(A=A, b=data)

    def __call__(self, algorithm):
        current_output = algorithm.get_output()
        self.save_values.append(self.f(current_output))

# FISTA with a lower bound of 0 in the TV term; the callback stores the
# data discrepancy each time it is called.
mycallback_FISTA_lower_bound= DataDiscrepancyCallback(A, absorption)
algo1=FISTA(initial=ig.allocate(0), f=F, g=alpha*TotalVariation(lower=0), update_objective_interval=10)
algo1.run(500, callbacks=[mycallback_FISTA_lower_bound])


# Same problem with no lower bound in the TV term
mycallback_FISTA_no_lower_bound= DataDiscrepancyCallback(A, absorption)
algo2=FISTA(initial=ig.allocate(0), f=F, g=alpha*TotalVariation(), update_objective_interval=10)
algo2.run(500, callbacks=[mycallback_FISTA_no_lower_bound])


show2D([ground_truth, algo1.get_output(), algo2.get_output()], title=['ground_truth', 'FISTA_lower_bound', 'FISTA_no_lower_bound'], num_cols=3)
# The first panel is the measured absorption data (not the ground truth image);
# the other two are the residuals A x - data for each reconstruction.
show2D([absorption, A.direct(algo1.get_output())-absorption, A.direct(algo2.get_output())-absorption], title=['Absorption data', 'Data error FISTA_lower_bound', 'Data error FISTA_no_lower_bound'], fix_range=[[0,3], [-0.02, 0.02], [-0.02, 0.02]], cmap=['gray', 'seismic', 'seismic'], num_cols=3)
# Skip the first 10 stored values so the early, very large discrepancies do not dominate the plot
plt.plot(range(10,501), mycallback_FISTA_lower_bound.save_values[10:], label='FISTA TV with lower bound ')
plt.plot(range(10, 501), mycallback_FISTA_no_lower_bound.save_values[10:], label='FISTA TV without lower bound ')
plt.yscale('log')
# Raw string: '\|' is not a valid Python escape sequence and triggers a warning otherwise
plt.ylabel(r'Data discrepancy $\|Ax-y\|_2^2$')
plt.xlabel('Iteration')
plt.legend()
../../_images/demos_callback_demonstration_14_0.png
../../_images/demos_callback_demonstration_14_1.png
[23]:
<matplotlib.legend.Legend at 0x7f08df8487f0>
../../_images/demos_callback_demonstration_14_3.png

We see that, without the lower bound, the reconstruction overfits to the noisy absorption data.

Calculating a noise approximation for each iteration (A custom callback example)#

[24]:
# Import the submodule explicitly: `skimage.restoration` is only guaranteed to be
# available as an attribute once it has been imported. This still binds `skimage`.
import skimage.restoration

class SigmaEstimateCallback(callbacks.Callback):
    """Record a noise estimate of the current output at every call.

    The estimate is ``skimage.restoration.estimate_sigma`` applied to the
    algorithm output converted with ``as_array()``.

    Attributes
    ----------
    save_values : list
        One estimate per call, in call order.
    """

    def __init__(self):
        self.save_values=[]

    def __call__(self, algorithm):
        # NOTE(review): the recorded run emits NumPy "Mean of empty slice" warnings,
        # presumably from estimating on the all-zero initial image — confirm.
        current_image = algorithm.get_output().as_array()
        self.save_values.append(skimage.restoration.estimate_sigma(current_image))

# FISTA with TV weight 0.1; the callback stores a noise estimate at each call
mycallback_FISTA_TV_alpha_01= SigmaEstimateCallback()
algo1=FISTA(initial=ig.allocate(0), f=F, g=0.1*TotalVariation(lower=0), update_objective_interval=10)
algo1.run(500, callbacks=[mycallback_FISTA_TV_alpha_01])


# Same problem with TV weight 1
mycallback_FISTA_TV_alpha_1= SigmaEstimateCallback()
algo2=FISTA(initial=ig.allocate(0), f=F, g=1*TotalVariation(lower=0), update_objective_interval=10)
algo2.run(500, callbacks=[mycallback_FISTA_TV_alpha_1])


show2D([ground_truth, algo1.get_output(), algo2.get_output()], title=['ground_truth', 'FISTA_TV_alpha_01', 'FISTA_TV_alpha_1'], num_cols=3)
# The first panel is the measured absorption data (not the ground truth image);
# the other two are the residuals A x - data for each reconstruction.
show2D([absorption, A.direct(algo1.get_output())-absorption, A.direct(algo2.get_output())-absorption], title=['Absorption data', 'Data error FISTA_TV_alpha_01', 'Data error FISTA_TV_alpha_1'], fix_range=[[0,3], [-0.02, 0.02], [-0.02, 0.02]], cmap=['gray', 'seismic', 'seismic'], num_cols=3)
# Skip the first 10 stored values when plotting
plt.plot(range(10,501), mycallback_FISTA_TV_alpha_01.save_values[10:], label='FISTA TV alpha=0.1 ')
plt.plot(range(10, 501), mycallback_FISTA_TV_alpha_1.save_values[10:], label='FISTA TV alpha=1.0 ')
plt.ylabel('Noise Estimate')
plt.xlabel('Iteration')
plt.legend()
/home/bih17925/miniconda3/envs/cil_testing2/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3432: RuntimeWarning: Mean of empty slice.
  return _methods._mean(a, axis=axis, dtype=dtype,
/home/bih17925/miniconda3/envs/cil_testing2/lib/python3.10/site-packages/numpy/core/_methods.py:190: RuntimeWarning: invalid value encountered in divide
  ret = ret.dtype.type(ret / rcount)
../../_images/demos_callback_demonstration_17_1.png
../../_images/demos_callback_demonstration_17_2.png
[24]:
<matplotlib.legend.Legend at 0x7f08dfcb7670>
../../_images/demos_callback_demonstration_17_4.png

We see that, with a larger regularisation parameter, the resulting image is less noisy.

Image metric callbacks (custom callback example)#

[25]:

class MetricsDiagnostics(callbacks.Callback):
    """Compute image metrics against a reference image and print them as a table.

    Parameters
    ----------
    reference_image :
        Image passed as the first argument to every metric function.
    metrics_dict : dict
        Maps a metric name to a callable ``f(reference_image, current_output)``.
    print_interval : int, default 1
        A row of metric values is printed when ``algo.iteration`` is a multiple
        of this value. A falsy value (``0`` or ``None``) disables the rows.

    Attributes
    ----------
    computed_metrics : list
        Flat list of every metric value computed so far, in call order.
    """

    def __init__(self, reference_image, metrics_dict, print_interval=1):
        self.reference_image = reference_image
        self.metrics_dict = metrics_dict
        self.computed_metrics = []
        self.print_interval = print_interval
        super(MetricsDiagnostics, self).__init__()

    def __call__(self, algo):
        current_output = algo.get_output()
        for metric_name, metric_func in self.metrics_dict.items():
            # The per-metric history is also kept on the algorithm object itself,
            # as an attribute named after the metric.
            if not hasattr(algo, metric_name):
                setattr(algo, metric_name, [])
            metric_value = metric_func(self.reference_image, current_output)
            getattr(algo, metric_name).append(metric_value)
            self.computed_metrics.append(metric_value)

        if algo.iteration == 0:
            print(self.callback_header())
        # Honour `print_interval`; the default of 1 prints a row on every call.
        if self.print_interval and algo.iteration % self.print_interval == 0:
            print(self.callback_iteration())

    def callback_header(self):
        """Return the metric names, each right-aligned in a 20-character column."""
        return " ".join("{:>20}".format(metric_name) for metric_name in self.metrics_dict.keys())

    def callback_iteration(self):
        """Return the most recent value of every metric, formatted as one table row."""
        latest_values = self.computed_metrics[-len(self.metrics_dict):]
        return " ".join("{:>20.5e}".format(metric) for metric in latest_values)


from cil.utilities.quality_measures import mae, psnr, mse

metric_callback= MetricsDiagnostics(ground_truth, {'MSE':mse, 'MAE':mae, 'PSNR':psnr})
algo=FISTA(initial=ig.allocate(0), f=F, g=G, update_objective_interval=10)
algo.run(100, callbacks=[metric_callback])
                 MSE                  MAE                 PSNR
         1.07888e-06          5.48145e-04          9.48530e+00
         5.85316e-07          6.22034e-04          1.21411e+01
         5.05844e-07          5.72563e-04          1.27749e+01
         4.31374e-07          5.19819e-04          1.34665e+01
         3.64704e-07          4.67054e-04          1.41956e+01
         3.06416e-07          4.16492e-04          1.49519e+01
         2.56388e-07          3.70092e-04          1.57261e+01
         2.14156e-07          3.28810e-04          1.65077e+01
         1.78987e-07          2.92725e-04          1.72868e+01
         1.50022e-07          2.60981e-04          1.80535e+01
         1.26383e-07          2.33361e-04          1.87981e+01
         1.07187e-07          2.09652e-04          1.95136e+01
         9.16141e-08          1.89309e-04          2.01954e+01
         7.89449e-08          1.72049e-04          2.08418e+01
         6.85910e-08          1.57283e-04          2.14524e+01
         6.00884e-08          1.44610e-04          2.20271e+01
         5.30737e-08          1.33746e-04          2.25662e+01
         4.72670e-08          1.24414e-04          2.30695e+01
         4.24393e-08          1.16364e-04          2.35374e+01
         3.84176e-08          1.09416e-04          2.39697e+01
         3.50543e-08          1.03451e-04          2.43676e+01
         3.22300e-08          9.82989e-05          2.47324e+01
         2.98493e-08          9.39156e-05          2.50657e+01
         2.78304e-08          9.01341e-05          2.53698e+01
         2.61075e-08          8.68679e-05          2.56474e+01
         2.46249e-08          8.40164e-05          2.59013e+01
         2.33423e-08          8.14809e-05          2.61336e+01
         2.22266e-08          7.92211e-05          2.63463e+01
         2.12462e-08          7.72332e-05          2.65422e+01
         2.03792e-08          7.54337e-05          2.67232e+01
         1.96080e-08          7.38151e-05          2.68907e+01
         1.89173e-08          7.23520e-05          2.70464e+01
         1.82934e-08          7.10307e-05          2.71921e+01
         1.77264e-08          6.98001e-05          2.73288e+01
         1.72101e-08          6.86725e-05          2.74572e+01
         1.67384e-08          6.76756e-05          2.75779e+01
         1.63068e-08          6.67997e-05          2.76913e+01
         1.59109e-08          6.59966e-05          2.77981e+01
         1.55498e-08          6.52429e-05          2.78978e+01
         1.52207e-08          6.45565e-05          2.79907e+01
         1.49199e-08          6.39012e-05          2.80774e+01
         1.46448e-08          6.32729e-05          2.81582e+01
         1.43935e-08          6.26837e-05          2.82334e+01
         1.41640e-08          6.21308e-05          2.83032e+01
         1.39533e-08          6.16084e-05          2.83683e+01
         1.37602e-08          6.11234e-05          2.84288e+01
         1.35827e-08          6.06739e-05          2.84852e+01
         1.34200e-08          6.02613e-05          2.85375e+01
         1.32710e-08          5.98831e-05          2.85860e+01
         1.31342e-08          5.95365e-05          2.86310e+01
         1.30086e-08          5.92132e-05          2.86727e+01
         1.28935e-08          5.89066e-05          2.87113e+01
         1.27882e-08          5.86154e-05          2.87469e+01
         1.26929e-08          5.83402e-05          2.87794e+01
         1.26069e-08          5.81077e-05          2.88090e+01
         1.25294e-08          5.79025e-05          2.88357e+01
         1.24593e-08          5.77139e-05          2.88601e+01
         1.23964e-08          5.75408e-05          2.88821e+01
         1.23400e-08          5.73717e-05          2.89019e+01
         1.22899e-08          5.72179e-05          2.89196e+01
         1.22457e-08          5.70800e-05          2.89352e+01
         1.22065e-08          5.69476e-05          2.89491e+01
         1.21716e-08          5.68219e-05          2.89616e+01
         1.21399e-08          5.67079e-05          2.89729e+01
         1.21121e-08          5.66082e-05          2.89828e+01
         1.20881e-08          5.65168e-05          2.89914e+01
         1.20672e-08          5.64386e-05          2.89990e+01
         1.20490e-08          5.63735e-05          2.90055e+01
         1.20338e-08          5.63197e-05          2.90110e+01
         1.20213e-08          5.62742e-05          2.90155e+01
         1.20117e-08          5.62335e-05          2.90190e+01
         1.20049e-08          5.61994e-05          2.90215e+01
         1.20006e-08          5.61720e-05          2.90230e+01
         1.19991e-08          5.61517e-05          2.90236e+01
         1.19998e-08          5.61385e-05          2.90233e+01
         1.20029e-08          5.61309e-05          2.90222e+01
         1.20088e-08          5.61240e-05          2.90201e+01
         1.20170e-08          5.61242e-05          2.90171e+01
         1.20275e-08          5.61325e-05          2.90133e+01
         1.20408e-08          5.61499e-05          2.90085e+01
         1.20565e-08          5.61750e-05          2.90028e+01
         1.20747e-08          5.62071e-05          2.89963e+01
         1.20954e-08          5.62405e-05          2.89888e+01
         1.21182e-08          5.62744e-05          2.89806e+01
         1.21432e-08          5.63137e-05          2.89717e+01
         1.21702e-08          5.63569e-05          2.89620e+01
         1.21990e-08          5.64026e-05          2.89518e+01
         1.22295e-08          5.64532e-05          2.89410e+01
         1.22611e-08          5.65052e-05          2.89297e+01
         1.22934e-08          5.65577e-05          2.89183e+01
         1.23272e-08          5.66137e-05          2.89064e+01
         1.23621e-08          5.66716e-05          2.88941e+01
         1.23983e-08          5.67352e-05          2.88814e+01
         1.24357e-08          5.68040e-05          2.88683e+01
         1.24743e-08          5.68758e-05          2.88549e+01
         1.25140e-08          5.69482e-05          2.88411e+01
         1.25548e-08          5.70229e-05          2.88269e+01
         1.25965e-08          5.71005e-05          2.88125e+01
         1.26388e-08          5.71802e-05          2.87980e+01
         1.26821e-08          5.72615e-05          2.87831e+01
         1.27264e-08          5.73452e-05          2.87680e+01

More complex example: image metric callbacks with regions of interest#

Warning: this is a complex example, but the code may be useful to adapt and reuse.

[26]:
class ImageQualityCallback(callbacks.Callback):
    """Record image metrics and statistics, globally and per region of interest.

    Parameters
    ----------
    reference_image :
        Reference image passed as the first argument to every metric.

    roi_mask_dict : dict, optional
        Maps a ROI name to a mask image with the same dimensions as the
        reference image. Non-zero voxels are treated as part of the ROI.
        Each mask must provide an ``array`` attribute. If ``None``, only
        global values are recorded.

    metrics_dict : dict, optional
        Maps a metric name to a callable. Each callable is invoked as
        ``metric(reference_image, algorithm.x)`` for the global value and as
        ``metric(reference_image, algorithm.x, mask=roi)`` for every ROI, so it
        must accept a ``mask`` keyword when ROIs are supplied.
        E.g. MSE, MAE, PSNR.

    statistics_dict : dict, optional
        Maps a statistic name to a callable ``f(x, where)``. ``x`` is
        ``algorithm.x.array`` and ``where`` is ``True`` for the global value or
        a boolean ROI mask array.
        E.g. ``lambda x, w: np.mean(x, where=w)``.

    Attributes
    ----------
    metrics_store : dict
        Lists of recorded values keyed by ``'global_<metric>'`` and
        ``'<roi>_<metric>'``.
    stat_store : dict
        Lists of recorded values keyed by ``'global_<statistic>'`` and
        ``'<roi>_<statistic>'``.
    """

    def __init__(self, reference_image,
                       roi_mask_dict   = None,
                       metrics_dict    = None,
                       statistics_dict = None,
                       ):

        # the reference image
        self.reference_image = reference_image

        # Not used by this class; kept so existing code reading them still works
        self.roi_indices_dict = {}
        self.roi_store = []

        self.roi_mask_dict = roi_mask_dict
        self.metrics_dict = metrics_dict
        self.statistics_dict = statistics_dict

        # Every argument may be None, so the stores are built defensively
        self.metrics_store = self._empty_store(metrics_dict, roi_mask_dict)
        self.stat_store = self._empty_store(statistics_dict, roi_mask_dict)

    @staticmethod
    def _empty_store(functions_dict, roi_mask_dict):
        """Create one empty list per global and per-ROI entry of `functions_dict`."""
        store = {}
        for name in (functions_dict or {}):
            store['global_' + name] = []
            for roi_name in (roi_mask_dict or {}):
                store[roi_name + '_' + name] = []
        return store

    def __call__(self, algorithm):
        # An absent ROI dictionary simply means "global values only"
        rois = self.roi_mask_dict or {}

        if self.metrics_dict is not None:
            for metric_name, metric in self.metrics_dict.items():
                ans = metric(self.reference_image, algorithm.x)
                self.metrics_store['global_' + metric_name].append(ans)

                for roi_name, roi in rois.items():
                    ans = metric(self.reference_image, algorithm.x, mask=roi)
                    self.metrics_store[roi_name + '_' + metric_name].append(ans)

        if self.statistics_dict is not None:
            for statistic_name, stat in self.statistics_dict.items():
                # `True` selects every element through the public NumPy `where`
                # argument, avoiding the private `np._NoValue` sentinel.
                ans = stat(algorithm.x.array, True)
                self.stat_store['global_' + statistic_name].append(ans)

                for roi_name, roi in rois.items():
                    ans = stat(algorithm.x.array, roi.array.astype('bool'))
                    self.stat_store[roi_name + '_' + statistic_name].append(ans)


[27]:
def mse(dc1, dc2, mask=None):
    """Mean squared error between two images, optionally restricted to a mask.

    Parameters
    ----------
    dc1: `DataContainer`
        First image of the comparison
    dc2: `DataContainer`
        Second image of the comparison
    mask: array or `DataContainer` with the same dimensions as `dc1` and `dc2`
        When given, only elements where the mask is True or non-zero contribute.

    Returns
    -------
    A number, the mean of the squared element-wise differences.
    """
    first = dc1.as_array()
    second = dc2.as_array()

    if mask is not None:
        # Accept either a raw array or a DataContainer as the mask
        selector = mask.as_array() if isinstance(mask, DataContainer) else mask
        selector = selector.astype('bool')
        first = np.extract(selector, first)
        second = np.extract(selector, second)

    residual = first - second
    return np.mean(residual ** 2)


def mae(dc1, dc2, mask=None):
    """Mean absolute error between two images, optionally restricted to a mask.

    Parameters
    ----------
    dc1: `DataContainer`
        First image of the comparison
    dc2: `DataContainer`
        Second image of the comparison
    mask: array or `DataContainer` with the same dimensions as `dc1` and `dc2`
        When given, only elements where the mask is True or non-zero contribute.

    Returns
    -------
    A number, the mean of the absolute element-wise differences.
    """
    first = dc1.as_array()
    second = dc2.as_array()

    if mask is not None:
        # Accept either a raw array or a DataContainer as the mask
        selector = mask.as_array() if isinstance(mask, DataContainer) else mask
        selector = selector.astype('bool')
        first = np.extract(selector, first)
        second = np.extract(selector, second)

    residual = first - second
    return np.mean(np.abs(residual))


def psnr(ground_truth, corrupted, mask=None):
    """Peak signal to noise ratio (PSNR) between a reference and an evaluated image.

    The peak value is the maximum of the ground truth array, restricted to the
    mask when one is given. The error term is `mse` with the same mask.

    Parameters
    ----------
    ground_truth: `DataContainer`
        The reference image
    corrupted: `DataContainer`
        The image to be evaluated
    mask: array or `DataContainer` with the same dimensions as the two images
        When given, only elements where the mask is True or non-zero contribute.

    Returns
    -------
    A number, the peak signal to noise ratio between the two images.
    """
    reference = ground_truth.as_array()

    if mask is not None:
        if isinstance(mask, DataContainer):
            mask = mask.as_array()
        # `initial` is required by NumPy whenever `where` is supplied to `max`
        peak = np.max(reference, where=mask.astype('bool'), initial=-1e-8)
    else:
        peak = reference.max()

    error = mse(ground_truth, corrupted, mask=mask)
    return 10 * np.log10((peak ** 2) / error)
[28]:
#%% create masks
# Both masks start as zero-filled images on the reconstruction geometry
top = ig.allocate(0)
bottom = ig.allocate(0)

# 'top': voxels strictly above 80% of the ground-truth maximum
top.fill((ground_truth.array > 0.8 * ground_truth.max()).astype(np.float32))

# 'bottom': voxels that are NOT below 40% of the ground-truth maximum
# NOTE(review): as written this also contains every 'top' voxel — confirm that
# this is the intended 'bottom' region.
bottom.fill(
    np.logical_not(ground_truth.array < 0.4 * ground_truth.max()).astype(np.float32)
)

roi_image_dict = {'top': top, 'bottom': bottom}

show2D([ground_truth, top, bottom], num_cols=3)
../../_images/demos_callback_demonstration_24_0.png
[28]:
<cil.utilities.display.show2D at 0x7f08dcf79120>
[29]:
# Metrics are compared against the ground truth; statistics are computed on the
# current iterate. Each statistic receives the array and a `where` selector.
img_qual_callback = ImageQualityCallback(
    ground_truth,
    roi_mask_dict=roi_image_dict,
    metrics_dict={
        'MSE': mse,
        'MAE': mae,
        'PSNR': psnr,
    },
    statistics_dict={
        'MEAN': (lambda values, selector: np.mean(values, where=selector)),
        'STDDEV': (lambda values, selector: np.std(values, where=selector)),
        'MAX': (lambda values, selector: np.max(values, where=selector, initial=0)),
    },
)
[30]:
# Restart from a zero image and record the image-quality values during the run
algo = FISTA(
    initial=ig.allocate(0),
    f=F,
    g=G,
    update_objective_interval=10,
)
algo.run(
    500,
    callbacks=[img_qual_callback],
)
show2D(
    [ground_truth, recon, algo.solution],
    title=['Ground Truth', 'FDK Reconstruction', 'TV solution'],
    origin='upper',
    num_cols=3,
)
../../_images/demos_callback_demonstration_26_0.png
[30]:
<cil.utilities.display.show2D at 0x7f08dcda0af0>
[31]:
# Global MSE recorded by the callback, plotted against 501 x-positions
plt.plot(
    range(501),
    img_qual_callback.metrics_store['global_MSE'],
)
[31]:
[<matplotlib.lines.Line2D at 0x7f08dec65b10>]
../../_images/demos_callback_demonstration_27_1.png
[32]:
# PSNR histories for the two ROIs and the whole image on one set of axes
plt.plot(
    range(501),
    img_qual_callback.metrics_store['top_PSNR'],
    label='Top',
)
plt.plot(
    range(501),
    img_qual_callback.metrics_store['global_PSNR'],
    label='Global',
)
plt.plot(
    range(501),
    img_qual_callback.metrics_store['bottom_PSNR'],
    label='Bottom',
)
plt.legend()
[32]:
<matplotlib.legend.Legend at 0x7f08dec67bb0>
../../_images/demos_callback_demonstration_28_1.png
[ ]: