# Source code for cil.utilities.display

#  Copyright 2019 United Kingdom Research and Innovation
#  Copyright 2019 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
# Kyle Pidgeon (UKRI-STFC)


#%%
from cil.framework import AcquisitionGeometry, AcquisitionData, ImageData, DataContainer, BlockDataContainer
from cil.framework.labels import AcquisitionType
import numpy as np
import warnings

import os
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from mpl_toolkits.axes_grid1 import make_axes_locatable
from itertools import cycle
import logging

log = logging.getLogger(__name__)


class _PlotData:
    '''Bundle of everything needed to draw one show2D sub-plot.

    Parameters
    ----------
    data : array-like
        The 2D data to display.
    title : str
        Sub-plot title.
    axis_labels : tuple or None
        (x_label, y_label) for the sub-plot.
    origin : str
        Display origin, e.g. 'lower-left'.

    ``range`` starts as None; show2D assigns a (min, max) display range to it
    when ``fix_range`` is requested.
    '''
    def __init__(self, data, title, axis_labels, origin):
        self.data, self.title = data, title
        self.axis_labels, self.origin = axis_labels, origin
        self.range = None

def set_origin(data, origin):
    '''Prepare 2D data for ``imshow`` according to a display origin.

    Parameters
    ----------
    data : numpy.ndarray or DataContainer
        2D data. Anything whose type is not exactly ``numpy.ndarray`` is
        converted with ``as_array()``.
    origin : str
        Display origin such as 'lower-left' or 'upper-right'. Only the
        substrings 'upper' and 'right' are inspected.

    Returns
    -------
    tuple
        (array, imshow origin 'lower' or 'upper', extent tuple of
        (horizontal start, horizontal end, vertical start, vertical end)).
        For a 'right' origin the array is flipped along axis 1 and the
        horizontal extent is reversed.
    '''
    # shape is read before any conversion, exactly as DataContainers expose it
    num_rows, num_cols = data.shape[0], data.shape[1]

    if type(data) is not np.ndarray:
        data = data.as_array()

    if 'upper' in origin:
        data_origin = 'upper'
        vertical = (num_rows, 0)
    else:
        data_origin = 'lower'
        vertical = (0, num_rows)

    if 'right' in origin:
        horizontal = (num_cols, 0)
        data = np.flip(data, 1)
    else:
        horizontal = (0, num_cols)

    return data, data_origin, horizontal + vertical

class show_base(object):
    '''Base class for the display classes; provides ``save`` for the figure
    stored in ``self.figure``.'''

    def save(self, filename, **kwargs):
        '''
        Saves the image using matplotlib.figure.savefig(). Default format is png if no extension is provided.

        matplotlib kwargs can be passed, refer to documentation
        https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html

        Parameters
        ----------
        filename : str
            Output path. If it has no extension and no ``format`` kwarg is
            given, '.png' is appended.
        **kwargs
            Forwarded to ``Figure.savefig``. ``format`` overrides the file
            extension. ``bbox_inches`` defaults to 'tight' but may be
            overridden by the caller.

        Raises
        ------
        ValueError
            If the requested format is not supported by the figure's canvas.

        Notes
        -----
        Failures while writing the file are reported with ``print`` rather
        than raised.
        '''
        file, extension = os.path.splitext(os.path.abspath(filename))
        extension = extension.strip('.')

        # Ask this figure's own canvas; plt.gcf() may return (or create) an
        # unrelated figure once the displayed one has been closed.
        extensions = list(self.figure.canvas.get_supported_filetypes().keys())

        fmt = kwargs.get('format', None)
        if fmt is not None:
            extension = fmt
        elif extension == '':
            extension = 'png'

        # matplotlib lowercases the format itself, so compare case-insensitively
        if extension.lower() not in extensions:
            raise ValueError("Extension not valid. Got {0}, backend supports {1}".format(extension, extensions))

        # a caller-supplied bbox_inches previously collided with the hard-coded one
        kwargs.setdefault('bbox_inches', 'tight')
        path_full = file + '.' + extension

        try:
            self.figure.set_tight_layout(True)
            self.figure.set_facecolor('w')
            self.figure.savefig(path_full, **kwargs)
            print("Saved image as {}".format(path_full))
        except PermissionError:
            print("Unable to save image. Permissions denied: {}".format(path_full))
        except Exception as err:
            # best-effort save: report the reason instead of silently swallowing it
            print("Unable to save image: {}".format(err))


class show1D(show_base):
    """
    Displays the 1D line profile of a N-dimensional dataset along a specified axis.
    If multiple slices are provided, a plot will be generated for each slice.
    If multiple datasets are provided, the same slice will be applied to each dataset.

    Parameters
    ----------
    data : ndarray i.e. DataContainer, numpy.ndarray or a list of ndarrays
        The data to be sliced and plotted. If a list of DataContainers is passed,
        they must all have the same shape.
    slice_list : list of tuples, or list of list of tuple, optional.
        Describes the slice to be taken from the data to extract a 1D vector.
        If None, the centre line of each axis will be displayed.
        Each slice should be a list of tuples, where each tuple is a (axis, index).
        Each axis can be described as a string or integer representing the axis
        to slice along. To plot multiple slices, pass a list of slices.
    dataset_labels : list of str, optional
        Overrides the default legend labels for each dataset displayed in the plot
        legend. Must be a list of strings equal to the number of datasets.
    title: str or list of str, optional
        Overrides the default title for each plot. A single string is used for every
        plot; a list must contain one string per plot.
    line_colours : str, list of str, default=None
        Overrides the default colour(s) for each line plot. A single string is
        applied to every dataset; a list must contain at least one colour per dataset,
        otherwise the default palette is used and a warning is logged.
    line_styles : {"-","--","-.",":"}, list of {"-","--","-.",":"}, default=None
        Overrides the default line style(s) for each line plot. A single string is
        applied to every dataset; a list must contain at least one style per dataset,
        otherwise the default styles are used and a warning is logged.
    axis_labels : tuple of str, list of tuples, optional
        Overrides the default axis labels in the form (x_axis_label, y_axis_label).
        Must be a single tuple or a list of tuples equal to the number of plots.
    size : tuple, default=(8,3)
        The size of each sub-plot in the figure.

    Note
    ----
    The figure can be saved using the `save` method. i.e:

    >>> fig = show1D(data, title='My Plot')
    >>> fig.save('/path/to/output.png')

    Examples
    --------
    This example uses a 3D dataset with no slicing information provided. The function
    displays 3 subplots with the centre line profile shown along each axis.

    >>> from cil.utilities.display import show1D
    >>> from cil.utilities import dataexample
    >>> data = dataexample.SIMULATED_SPHERE_VOLUME.get()
    >>> show1D(data, slice_list=None)

    The following example shows a single line profile from a 3D dataset along the
    vertical axis.

    >>> from cil.utilities.display import show1D
    >>> from cil.utilities import dataexample
    >>> data = dataexample.SIMULATED_SPHERE_VOLUME.get()
    >>> show1D(data, slice_list=[("Horizontal_x", 30),("Horizontal_y", 40)])

    The following example shows multiple line profiles from a 3D dataset along the
    vertical axis.

    >>> from cil.utilities.display import show1D
    >>> from cil.utilities import dataexample
    >>> data = dataexample.SIMULATED_SPHERE_VOLUME.get()
    >>> show1D(data, slice_list=[[("Vertical",62),("Horizontal_y",70)],[("Vertical",61),("Horizontal_y",70)],[("Vertical",67),("Horizontal_x",70)]])

    The following example plots the same line profile from two datasets on the same plot.

    >>> from cil.utilities.display import show1D
    >>> from cil.utilities import dataexample
    >>> data = dataexample.SIMULATED_SPHERE_VOLUME.get()
    >>> data2 = data*0.9
    >>> show1D([data, data2], slice_list=[("Horizontal_x", 30),("Horizontal_y", 40)])

    The following example shows the case where the input data is a numpy array.
    The slice_list can be provided as (axis, index), where axis is an integer.

    >>> from cil.utilities.display import show1D
    >>> import numpy as np
    >>> data = np.random.rand(10,10,10)
    >>> show1D(data, slice_list=[(0, 3),(2, 5)])

    To save the figure, use the `save` method.

    >>> from cil.utilities.display import show1D
    >>> from cil.utilities import dataexample
    >>> data = dataexample.SIMULATED_SPHERE_VOLUME.get()
    >>> fig = show1D(data, slice_list=None)
    >>> fig.save('/path/to/output.png')
    """

    def __init__(self, data, slice_list=None, dataset_labels=None, title=None,
                 line_colours=None, line_styles=None, axis_labels=None, size=(8,3)):
        self.figure = self._show1d(data, slice_list, dataset_labels=dataset_labels, title=title,
                                   line_colours=line_colours, line_styles=line_styles,
                                   axis_labels=axis_labels, size=size)

    def _parse_slice(self, slice_list_in, items_per_slice, ndim, labels):
        """
        Internal function to parse the slice_list input and convert it to a slice
        object to be applied to the data.

        Parameters
        ----------
        slice_list_in : list of tuple
            (axis, index) pairs; axis is an integer or a dimension label.
        items_per_slice : int
            Number of axes that must be fixed (ndim - 1) to leave a 1D vector.
        ndim : int
            Number of dimensions of the data.
        labels : sequence of str
            Dimension labels used to resolve string axes.

        Returns
        -------
        tuple of slice
            ``slice(None)`` on the free axis, ``slice(index, index + 1)`` elsewhere.

        Raises
        ------
        ValueError
            Wrong number of entries, unknown label, axis out of range or duplicate axes.
        """
        # check enough slices are provided to output a single vector
        if len(slice_list_in) != items_per_slice:
            raise ValueError("slice_list must provide a slice for ndim - 1 axes")

        # convert axis labels to non-negative indices
        axis_indices = []
        for sl in slice_list_in:
            try:
                ind = int(sl[0])
            except (TypeError, ValueError):
                if sl[0] in labels:
                    ind = labels.index(sl[0])
                else:
                    raise ValueError(f"Invalid axis label: {sl[0]}")
            if not -ndim <= ind < ndim:
                raise ValueError(f"Axis {sl[0]} is out of range for data with {ndim} dimensions")
            # normalise negatives so the duplicate check compares like with like
            axis_indices.append(ind % ndim)

        if len(set(axis_indices)) != items_per_slice:
            raise ValueError("slice_list contains duplicate axes. Each axis must be unique.")

        slice_ndim = [slice(None)] * ndim
        for axis, sl in zip(axis_indices, slice_list_in):
            slice_ndim[axis] = slice(sl[1], sl[1] + 1)
        return tuple(slice_ndim)

    def _show1d(self, data, slice_list=None, dataset_labels=None, title=None,
                line_colours=None, line_styles=None, axis_labels=None, size=(8,3)):
        """
        Internal function to display 1D plots of pixel flux from multi-dimensional
        data and slicing information.

        Returns
        -------
        matplotlib.figure.Figure
            The figure with one sub-plot per requested slice.

        Raises
        ------
        ValueError
            If no data is given, datasets differ in shape, the slice description is
            invalid, or dataset_labels / title / axis_labels do not match the number
            of datasets / plots.
        """
        # a single array-like is wrapped so one and many datasets share a code path
        if hasattr(data, 'ndim'):
            data_list = [data]
        else:
            data_list = data
        num_data = len(data_list)
        if num_data == 0:
            raise ValueError("data must contain at least one dataset")

        for i in range(1, num_data):
            if data_list[i].shape != data_list[0].shape:
                raise ValueError("All datasets must have the same shape")

        data_ndim = data_list[0].ndim
        data_shape = data_list[0].shape
        items_per_slice = len(data_shape) - 1

        labels = getattr(data_list[0], 'dimension_labels', None)
        if labels is None:
            labels = [f"Axis {i}" for i in range(data_ndim)]

        if slice_list is None:
            # one plot per axis, every other axis fixed at its centre index
            _slice_list = []
            for i in range(data_ndim):
                slice_ndim = [slice(n // 2, n // 2 + 1) for n in data_shape]
                slice_ndim[i] = slice(None)
                _slice_list.append(tuple(slice_ndim))
        elif isinstance(slice_list, list) and (len(slice_list) == 0 or isinstance(slice_list[0], tuple)):
            # single slice: [(axis, index), ...] (empty only valid for 1D data)
            _slice_list = [self._parse_slice(slice_list, items_per_slice, data_ndim, labels)]
        elif isinstance(slice_list, list) and isinstance(slice_list[0], list):
            # multiple slices: [[(axis, index), ...], [(axis, index), ...]]
            _slice_list = [self._parse_slice(sl, items_per_slice, data_ndim, labels) for sl in slice_list]
        else:
            # any other type previously fell through and crashed with UnboundLocalError
            raise ValueError("slice_list must contain a valid list of slices to generate 1D plots")

        num_sub_plots = len(_slice_list)

        # validate per-dataset / per-plot options before any figure is created
        if dataset_labels is not None and not (isinstance(dataset_labels, list) and len(dataset_labels) == num_data):
            raise ValueError("dataset_labels must be a list of strings equal to the number of datasets")
        if title is not None and not isinstance(title, str) \
                and not (isinstance(title, list) and len(title) == num_sub_plots):
            raise ValueError("title must be a list of strings equal to the number of plots")
        if axis_labels is not None and not isinstance(axis_labels, tuple) \
                and not (isinstance(axis_labels, list) and len(axis_labels) == num_sub_plots):
            raise ValueError("axis_labels must be a tuple or a list of tuples equal to the number of plots")

        # set up plots
        CB_PALETTE = ['#377eb8', '#ff7f00', '#4daf4a',
                      '#f781bf', '#a65628', '#984ea3',
                      '#999999', '#e41a1c', '#dede00']
        DEFAULT_STYLES = ["-", "--", "-.", ":"]

        # a single colour/style string applies to every dataset; indexing the
        # string itself would pick characters such as '#' or '-'
        if isinstance(line_colours, str):
            line_colours = [line_colours] * num_data
        if isinstance(line_styles, str):
            line_styles = [line_styles] * num_data

        if line_colours is not None and len(line_colours) < num_data:
            log.warning("line_colours must be a list of colours at least as long as the number of datasets, using default colour palette")
            line_colours = None
        if line_styles is not None and len(line_styles) < num_data:
            log.warning("line_styles must be a list of styles at least as long as the number of datasets, using default line styles")
            line_styles = None

        fig_size = (size[0], size[1] * num_sub_plots)
        # squeeze=False keeps an array of Axes even for a single sub-plot
        fig, ax = plt.subplots(num_sub_plots, figsize=fig_size, squeeze=False)
        axes = ax.ravel()

        for i, (sub_ax, slc) in enumerate(zip(axes, _slice_list)):
            # restart the default colour and line-style cycles for every sub-plot
            colour_cyc = cycle(CB_PALETTE)
            ls_cyc = cycle(DEFAULT_STYLES)

            data_line = None
            for j, x in enumerate(data_list):
                arr = x.as_array() if hasattr(x, 'as_array') else x
                _cl = next(colour_cyc) if line_colours is None else line_colours[j]
                _ls = next(ls_cyc) if line_styles is None else line_styles[j]
                # every fixed axis has length 1, so ravel leaves the free axis
                # (also when that axis has length 1, where squeeze gave a 0-d array)
                data_line = np.ravel(arr[slc])
                _lbl = f'Dataset {j}' if dataset_labels is None else dataset_labels[j]
                sub_ax.plot(data_line, color=_cl, linestyle=_ls, label=_lbl)

            # default title lists the fixed axes; x label names the free axis
            x_label = None
            fixed = []
            for k, sl in enumerate(slc):
                if sl == slice(None):
                    x_label = f"{labels[k]} index"
                else:
                    fixed.append(f"{labels[k]}:{sl.start}")
            subplot_title = ("Slice at " + ", ".join(fixed)) if fixed else ""

            if isinstance(title, str):
                subplot_title = title
            elif title is not None:
                subplot_title = title[i]
            sub_ax.set_title(subplot_title)

            if isinstance(axis_labels, tuple):
                sub_ax.set_xlabel(axis_labels[0])
                sub_ax.set_ylabel(axis_labels[1])
            elif axis_labels is not None:
                sub_ax.set_xlabel(axis_labels[i][0])
                sub_ax.set_ylabel(axis_labels[i][1])
            else:
                sub_ax.set_xlabel(x_label)
                sub_ax.set_ylabel("Value")

            # identical limits for a single sample would trigger a matplotlib warning
            if len(data_line) > 1:
                sub_ax.set_xlim(0, len(data_line) - 1)

            if num_data > 1:
                sub_ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        fig.set_tight_layout(True)
        fig.set_facecolor('w')

        # keep our own reference: plt.show() may close the figure / change the current one
        plt.show()
        return fig
class show2D(show_base):
    '''This plots 2D slices from cil DataContainer types.

    Plots 1 or more 2D plots in an (n x num_cols) matrix.
    Can plot multiple slices from one 3D dataset, or compare multiple datasets.
    Inputs can be single arguments or list of arguments that will be sequentially
    applied to subplots.
    If no slice_list is passed a 3D dataset will display the centre slice of the
    outer dimension, a 4D dataset will show the centre slices of the two outer
    dimensions.

    Parameters
    ----------
    datacontainers: ImageData, AcquisitionData, list of ImageData / AcquisitionData, BlockDataContainer
        The DataContainers to be displayed
    title: string, list of strings, optional
        The title for each figure
    slice_list: tuple, int, list of tuples, list of ints, optional
        The slices to show. A list of integers will show slices for the outer
        dimension. For 3D datacontainers single slice: (direction, index).
        For 4D datacontainers two slices: [(direction0, index),(direction1, index)].
    fix_range: boolean, tuple, list of tuples
        Sets the display range of the data. `True` sets all plots to the global (min, max).
    axis_labels: tuple, list of tuples, optional
        The axis labels for each figure e.g. ('x','y')
    origin: string, list of strings
        Sets the display origin. 'lower/upper-left/right'
    cmap: str, list or tuple of strings
        Sets the colour map of the plot (see matplotlib.pyplot). If passed a list
        or tuple of the length of datacontainers, allows to set a different color
        map for each datacontainer.
    num_cols: int
        Sets the number of columns of subplots to display
    size: tuple
        Figure size in inches

    Note
    ----
    The figure can be saved using the `save` method. i.e:

    >>> fig = show2D(data, title='My Plot')
    >>> fig.save('/path/to/output.png')
    '''

    def __init__(self, datacontainers, title=None, slice_list=None, fix_range=False, axis_labels=None, origin='lower-left', cmap='gray', num_cols=2, size=(15,15)):
        self.figure = self.__show2D(datacontainers, title=title, slice_list=slice_list, fix_range=fix_range, axis_labels=axis_labels, origin=origin, cmap=cmap, num_cols=num_cols, size=size)

    @staticmethod
    def _is_int(value):
        '''True for Python and NumPy integers (bool excluded).'''
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    @staticmethod
    def _slice_error(slice_requested, data):
        '''Message used when a 2D slice cannot be obtained from ``data``.'''
        return ("Unable to slice input data. Could not obtain 2D slice {0} from {1} with shape {2}.\n"
                "Pass either correct slice information or a 2D array").format(slice_requested, type(data), data.shape)

    @staticmethod
    def _axis_index(data, axis):
        '''Convert a dimension label to its integer index; integers pass through.'''
        if isinstance(axis, str):
            return list(data.dimension_labels).index(axis)
        return axis

    def _count_plots(self, datacontainers, slice_list, multiple):
        '''Number of sub-plots: one per dataset, or one per requested slice.'''
        if multiple:
            return len(datacontainers)
        dim = len(datacontainers.shape)
        if slice_list is None or dim == 2 or type(slice_list) is tuple or self._is_int(slice_list):
            return 1
        if dim == 4 and type(slice_list[0]) is tuple:
            # [(direction0, ind), (direction1, ind)] describes a single 4D slice
            return 1
        return len(slice_list)

    def _select_4d_slice(self, slice_list, i):
        '''Pick the slice description for sub-plot ``i`` of 4D data.'''
        if (slice_list is None or type(slice_list) is tuple or self._is_int(slice_list)
                or type(slice_list[0]) is tuple):
            # None, (direction, ind), ind or [(direction0, ind), (direction1, ind)]: applies to all
            return slice_list
        item = slice_list[i]
        if self._is_int(item) or len(item) > 1:
            # [ind0, ind1, ...] of direction0, or [[(d0, ind), (d1, ind)], ...]
            return item
        # [[(direction0, ind)], [(direction0, ind)]]
        return item[0]

    def _slice_4d(self, data, slice_requested):
        '''Reduce 4D data to 2D by fixing two axes (centre of axes 0 and 1 by default).

        Returns (plot_data, subtitle); raises TypeError if slicing fails.
        '''
        cut_axis = [0, 1]
        cut_slices = [data.shape[0] // 2, data.shape[1] // 2]
        try:
            if self._is_int(slice_requested):
                # index along axis 0
                cut_slices[0] = slice_requested
            elif type(slice_requested) is tuple:
                # (direction, index): axis 0 replaces the first cut, any other axis the second.
                # Labels are resolved to integers first (previously `x is str` never matched,
                # leaving a label that was then compared against an int).
                axis = self._axis_index(data, slice_requested[0])
                slot = 0 if axis == 0 else 1
                cut_axis[slot] = axis
                cut_slices[slot] = slice_requested[1]
            elif type(slice_requested) is list:
                # [(direction0, index), (direction1, index)]
                cut_axis = [self._axis_index(data, slice_requested[0][0]),
                            self._axis_index(data, slice_requested[1][0])]
                cut_slices = [slice_requested[0][1], slice_requested[1][1]]

            # take() removes the higher axis first so the lower axis index stays valid
            if cut_axis[0] > cut_axis[1]:
                cut_axis.reverse()
                cut_slices.reverse()

            if hasattr(data, 'get_slice'):
                labels = data.dimension_labels
                cut_axis = [labels[a] if self._is_int(a) else a for a in cut_axis]
                plot_data = data.get_slice(**{cut_axis[0]: cut_slices[0], cut_axis[1]: cut_slices[1]}, force=True)
            else:
                array = data.as_array() if hasattr(data, 'as_array') else data
                plot_data = array.take(indices=cut_slices[1], axis=cut_axis[1])
                plot_data = plot_data.take(indices=cut_slices[0], axis=cut_axis[0])
        except Exception as err:
            raise TypeError(self._slice_error(slice_requested, data)) from err

        subtitle = "direction: ({0},{1}), slice: ({2},{3})".format(*cut_axis, *cut_slices)
        return plot_data, subtitle

    def _slice_3d(self, data, slice_requested):
        '''Reduce 3D data to 2D by fixing one axis (centre of axis 0 by default).

        Returns (plot_data, subtitle); raises TypeError if slicing fails.
        '''
        cut_axis = 0
        cut_slice = data.shape[0] // 2
        if self._is_int(slice_requested):
            cut_slice = slice_requested
        elif type(slice_requested) is tuple:
            cut_axis, cut_slice = slice_requested[0], slice_requested[1]

        try:
            if hasattr(data, 'get_slice'):
                if self._is_int(cut_axis):
                    cut_axis = data.dimension_labels[cut_axis]
                plot_data = data.get_slice(**{cut_axis: cut_slice}, force=True)
            else:
                array = data.as_array() if hasattr(data, 'as_array') else data
                plot_data = array.take(indices=cut_slice, axis=cut_axis)
        except Exception as err:
            raise TypeError(self._slice_error(slice_requested, data)) from err

        return plot_data, "direction: {0}, slice: {1}".format(cut_axis, cut_slice)

    def __show2D(self, datacontainers, title=None, slice_list=None, fix_range=False, axis_labels=None, origin='lower-left', cmap='gray', num_cols=2, size=(15,15)):
        '''Build and show the figure; returns the matplotlib Figure.'''
        multiple = isinstance(datacontainers, (list, BlockDataContainer))
        num_plots = self._count_plots(datacontainers, slice_list, multiple)
        if num_plots == 0:
            raise ValueError("No data to display")

        subplots = []
        # the global range is taken over the sliced 2D data actually displayed
        range_min = float("inf")
        range_max = -range_min

        # every input can be given once or once per sub-plot
        for i in range(num_plots):
            data = datacontainers[i] if multiple else datacontainers
            ndim = len(data.shape)
            if ndim == 4:
                slice_requested = self._select_4d_slice(slice_list, i)
                plot_data, subtitle = self._slice_4d(data, slice_requested)
            elif ndim == 3:
                # list: one entry per sub-plot; otherwise one value for all datasets
                slice_requested = slice_list[i] if type(slice_list) is list else slice_list
                plot_data, subtitle = self._slice_3d(data, slice_requested)
            else:
                # defined here too so the error message below can always be built
                slice_requested = slice_list
                plot_data, subtitle = data, None

            if len(plot_data.shape) != 2:
                raise TypeError(self._slice_error(slice_requested, data))

            plot_axis_labels = axis_labels[i] if type(axis_labels) is list else axis_labels
            if plot_axis_labels is None and hasattr(plot_data, 'dimension_labels'):
                plot_axis_labels = (plot_data.dimension_labels[1], plot_data.dimension_labels[0])

            range_min = min(range_min, plot_data.min())
            range_max = max(range_max, plot_data.max())

            plot_title = title[i] if isinstance(title, list) else title
            if plot_title is None:
                plot_title = ''
            if subtitle is not None:
                plot_title += '\n' + subtitle

            plot_origin = origin[i] if isinstance(origin, list) else origin
            subplots.append(_PlotData(plot_data, plot_title, plot_axis_labels, plot_origin))

        # set range per subplot
        for i, subplot in enumerate(subplots):
            if fix_range is None or fix_range is False:
                continue
            if fix_range is True:
                subplot.range = (range_min, range_max)
            elif type(fix_range) is list:
                subplot.range = fix_range[i]
            else:
                subplot.range = (fix_range[0], fix_range[1])

        num_cols = min(num_cols, num_plots)
        # ceiling division: round((n + 0.5) / c) under-counted rows (e.g. 5 plots in
        # 4 columns gave 1 row -> IndexError) and added a blank row for a single plot
        num_rows = -(-num_plots // num_cols)
        # squeeze=False keeps a 2D array of Axes even for a 1x1 grid
        fig, ax = plt.subplots(num_rows, num_cols, figsize=size, squeeze=False)
        axes = ax.flatten()

        # unused grid cells stay hidden
        for sub_ax in axes:
            sub_ax.set_visible(False)

        for i, subplot in enumerate(subplots):
            sub_ax = axes[i]
            sub_ax.set_visible(True)
            sub_ax.set_title(subplot.title)

            if subplot.axis_labels is not None:
                sub_ax.set_ylabel(subplot.axis_labels[1])
                sub_ax.set_xlabel(subplot.axis_labels[0])

            data, data_origin, extent = set_origin(subplot.data, subplot.origin)
            dcmap = cmap[i] if isinstance(cmap, (list, tuple)) else cmap
            sp = sub_ax.imshow(data, cmap=dcmap, origin=data_origin, extent=extent)

            im_ratio = subplot.data.shape[0] / subplot.data.shape[1]

            y_axes2 = None
            if isinstance(subplot.data, AcquisitionData) and sub_ax.get_ylabel() == 'angle':
                # relabel the angle-index ticks with the corresponding angle values
                locs = sub_ax.get_yticks()
                location_new = locs[0:-1].astype(int)
                ang = subplot.data.geometry.config.angles
                labels_new = ["{:.2f}".format(a) for a in np.take(ang.angle_data, location_new)]
                sub_ax.set_yticks(location_new, labels=labels_new)
                sub_ax.set_ylabel('angle / ' + str(ang.angle_unit))

                y_axes2 = sub_ax.secondary_yaxis('right')
                y_axes2.set_ylabel('angle / index')

            if subplot.data.shape[0] < subplot.data.shape[1] // 2:
                # stretch very wide images so they remain readable
                sub_ax.set_aspect(1 / im_ratio)
                im_ratio = 1

            if y_axes2 is not None:
                scale = 0.041 * im_ratio
                pad = 0.12
            else:
                scale = 0.0467 * im_ratio
                pad = 0.02

            plt.colorbar(sp, orientation='vertical', ax=sub_ax, fraction=scale, pad=pad)

            if subplot.range is not None:
                sp.set_clim(subplot.range[0], subplot.range[1])

        fig.set_tight_layout(True)
        fig.set_facecolor('w')

        # keep our own reference: plt.show() may close the figure / change the current one
        plt.show()
        return fig
def plotter2D(datacontainers, title=None, slice_list=None, fix_range=False, axis_labels=None, origin='lower-left', cmap='gray', num_cols=2, size=(15,15)):
    '''Alias of show2D'''
    return show2D(datacontainers, title=title, slice_list=slice_list, fix_range=fix_range, axis_labels=axis_labels, origin=origin, cmap=cmap, num_cols=num_cols, size=size)


class _Arrow3D(FancyArrowPatch):
    '''FancyArrowPatch between two points given in 3D data coordinates.

    ``xs``, ``ys``, ``zs`` each hold the (start, end) coordinate along one axis.
    '''

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # the 2D positions are placeholders; they are recomputed from the 3D
        # vertices whenever the arrow is projected or drawn
        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        FancyArrowPatch.draw(self, renderer)

    def do_3d_projection(self, renderer=None):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return np.min(zs)


class _ShowGeometry(object):
    '''Draws an acquisition geometry (and its image geometry) on a 3D matplotlib axis.

    2D geometries are rebuilt as 3D geometries with a zero third coordinate so the
    same drawing code is used for both.
    '''

    def __init__(self, acquisition_geometry, image_geometry=None):
        if AcquisitionType.DIM2 & acquisition_geometry.dimension:
            self.ndim = 2
            system = acquisition_geometry.config.system
            # embed the 2D system in 3D: append 0 to every position/direction and
            # pass [0,0,1] for the remaining direction arguments
            if acquisition_geometry.geom_type == 'cone':
                ag_temp = AcquisitionGeometry.create_Cone3D([*system.source.position, 0], [*system.detector.position, 0], [*system.detector.direction_x, 0], [0, 0, 1], [*system.rotation_axis.position, 0], [0, 0, 1])
            else:
                ag_temp = AcquisitionGeometry.create_Parallel3D([*system.ray.direction, 0], [*system.detector.position, 0], [*system.detector.direction_x, 0], [0, 0, 1], [*system.rotation_axis.position, 0], [0, 0, 1])

            ag_temp.config.panel = acquisition_geometry.config.panel
            ag_temp.set_angles(acquisition_geometry.angles)
            ag_temp.set_labels(['vertical', *acquisition_geometry.dimension_labels])

            self.acquisition_geometry = ag_temp

        elif acquisition_geometry.channels > 1:
            # only the geometry of the first channel is drawn
            self.ndim = 3
            self.acquisition_geometry = acquisition_geometry.get_slice(channel=0)
        else:
            self.acquisition_geometry = acquisition_geometry
            self.ndim = 3

        if image_geometry is None:
            self.image_geometry = self.acquisition_geometry.get_ImageGeometry()
        else:
            self.image_geometry = image_geometry

        # arrow lengths are scaled to a fifth of the largest detector extent
        panel = self.acquisition_geometry.config.panel
        len1 = panel.num_pixels[0] * panel.pixel_size[0]
        len2 = panel.num_pixels[1] * panel.pixel_size[1]
        self.scale = max(len1, len2) / 5

        self.handles = []
        self.labels = []

    def draw(self, elev=35, azim=35, view_distance=10, grid=False, figsize=(10,10), fontsize=10):
        '''Create the figure, draw all components and show it; returns the Figure.'''
        # start from an empty legend so repeated calls do not duplicate entries
        self.handles = []
        self.labels = []

        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.add_subplot(111, projection='3d')

        self.text_options = {'horizontalalignment': 'center',
                             'verticalalignment': 'center',
                             'fontsize': fontsize}

        self.display_world()

        if self.acquisition_geometry.geom_type == 'cone':
            self.display_source()
        else:
            self.display_ray()

        self.display_object()
        self.display_detector()

        if grid is False:
            self.ax.set_axis_off()

        self.ax.view_init(elev=elev, azim=azim)
        self.ax.dist = view_distance

        # force a 1:1:1 data aspect ratio (same values Axes3D.get_w_lims returned)
        xmin, xmax = self.ax.get_xlim3d()
        ymin, ymax = self.ax.get_ylim3d()
        zmin, zmax = self.ax.get_zlim3d()
        self.ax.set_box_aspect((xmax - xmin, ymax - ymin, zmax - zmin))

        # invisible spacer entries pad the 3-column legend layout
        # (np.nan: the np.NaN alias was removed in NumPy 2.0)
        spacer = self.ax.plot(np.nan, np.nan, '-', color='none', label='')[0]
        for _ in range(3):
            self.handles.insert(2, spacer)
            self.labels.insert(2, '')

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.ax.legend(self.handles, self.labels, loc='upper left', bbox_to_anchor=(0, 1), ncol=3,
                           borderaxespad=0, frameon=False, fontsize=self.text_options.get('fontsize'))

        self.fig.set_tight_layout(True)
        self.fig.set_facecolor('w')

        # keep our own reference: plt.show() may close the figure / change the current one
        fig = self.fig
        plt.show()
        return fig

    def display_world(self):
        '''Draw the world origin and coordinate axes.'''
        self.ax.set_xlabel('X axis')
        self.ax.set_ylabel('Y axis')

        if self.ndim == 3:
            self.ax.set_zlabel('Z axis')
        else:
            self.ax.set_zticks([])

        # origin and coordinate frame
        Oo = np.zeros(3)
        self.ax.scatter3D(*Oo, marker='o', alpha=1, color='k', lw=1)

        h = mlines.Line2D([], [], color='k', linestyle='solid', markersize=12, label='world coordinate system')

        labels = ['$x$', '$y$', '$z$']
        for i in range(self.ndim):
            axis = np.zeros(3)
            axis[i] = 1 * self.scale
            a = _Arrow3D(*zip(Oo, axis * 2), mutation_scale=20, lw=1, arrowstyle="->", color="k")
            self.ax.add_artist(a)
            self.ax.text(*(axis * 2.2), labels[i], **self.text_options)

        self.handles.append(h)
        self.labels.append(h.get_label())

    def detector_vertex(self):
        '''Return the detector corners [rb, lb, lt, rt] (3D) or end points [r, l] (2D).'''
        panel = self.acquisition_geometry.config.panel
        detector = self.acquisition_geometry.config.system.detector

        # half of the detector extent along each panel direction
        det_size = (np.array(panel.num_pixels) * np.array(panel.pixel_size)) / 2
        det_rows_dir = detector.direction_x

        if self.ndim == 3:
            det_v = detector.direction_y * det_size[1]
            det_h = det_rows_dir * det_size[0]

            rt = det_h + det_v + detector.position
            lt = -det_h + det_v + detector.position
            lb = -det_h - det_v + detector.position
            rb = det_h - det_v + detector.position

            return [rb, lb, lt, rt]

        det_h = det_rows_dir * det_size[0]
        r = det_h + detector.position
        l = -det_h + detector.position
        return [r, l]

    def display_detector(self):
        '''Draw the detector outline, its direction vectors and pixel 0.'''
        do = self.acquisition_geometry.config.system.detector.position
        det = self.detector_vertex()
        panel_origin = self.acquisition_geometry.config.panel.origin

        # mark data origin: pick the corner from the panel origin string
        bottom = self.ndim == 2 or 'bottom' in panel_origin
        if 'right' in panel_origin:
            pix0 = det[0] if bottom else det[3]
        else:
            pix0 = det[1] if bottom else det[2]

        det_rows_dir = self.acquisition_geometry.config.system.detector.direction_x

        x = _Arrow3D(*zip(do, self.scale * det_rows_dir + do), mutation_scale=20, lw=1, arrowstyle="-|>", color="b")
        self.ax.add_artist(x)
        self.ax.text(*(1.2 * self.scale * det_rows_dir + do), r'$D_x$', **self.text_options)

        if self.ndim == 3:
            det_col_dir = self.acquisition_geometry.config.system.detector.direction_y
            y = _Arrow3D(*zip(do, self.scale * det_col_dir + do), mutation_scale=20, lw=1, arrowstyle="-|>", color="b")
            self.ax.add_artist(y)
            self.ax.text(*(1.2 * self.scale * det_col_dir + do), r'$D_y$', **self.text_options)

        handles = [
            self.ax.scatter3D(*do, marker='o', alpha=1, color='b', lw=1, label='detector position'),
            mlines.Line2D([], [], color='b', linestyle='solid', markersize=12, label='detector direction'),
            self.ax.plot3D(*zip(*det, det[0]), color='b', ls='dotted', alpha=1, label='detector')[0],
            self.ax.scatter3D(*pix0, marker='x', alpha=1, color='b', lw=1, s=50, label='data origin (pixel 0)'),
        ]

        for handle in handles:
            self.handles.append(handle)
            self.labels.append(handle.get_label())

    def display_object(self):
        '''Draw the rotation axis, image-geometry volume, voxel 0 and rotation direction.'''
        ro = self.acquisition_geometry.config.system.rotation_axis.position
        h0 = self.ax.scatter3D(*ro, marker='o', alpha=1, color='r', lw=1, label='rotation axis position')
        self.handles.append(h0)
        self.labels.append(h0.get_label())

        if self.ndim == 3:
            # rotate axis arrow
            r1 = ro + self.acquisition_geometry.config.system.rotation_axis.direction * self.scale * 2
            arrow3 = _Arrow3D(*zip(ro, r1), mutation_scale=20, lw=1, arrowstyle="-|>", color="r")
            self.ax.add_artist(arrow3)

            a = self.acquisition_geometry.config.system.rotation_axis.direction

            # draw reco volume edges
            x = np.array([self.image_geometry.get_min_x(), self.image_geometry.get_max_x()])
            y = np.array([self.image_geometry.get_min_y(), self.image_geometry.get_max_y()])
            z = np.array([self.image_geometry.get_min_z(), self.image_geometry.get_max_z()])

            combos = [
                ((x[0], y[0], z[0]), (x[0], y[1], z[0])),
                ((x[0], y[1], z[0]), (x[1], y[1], z[0])),
                ((x[1], y[1], z[0]), (x[1], y[0], z[0])),
                ((x[1], y[0], z[0]), (x[0], y[0], z[0])),
                ((x[0], y[0], z[1]), (x[0], y[1], z[1])),
                ((x[0], y[1], z[1]), (x[1], y[1], z[1])),
                ((x[1], y[1], z[1]), (x[1], y[0], z[1])),
                ((x[1], y[0], z[1]), (x[0], y[0], z[1])),
                ((x[0], y[0], z[0]), (x[0], y[0], z[1])),
                ((x[0], y[1], z[0]), (x[0], y[1], z[1])),
                ((x[1], y[1], z[0]), (x[1], y[1], z[1])),
                ((x[1], y[0], z[0]), (x[1], y[0], z[1])),
            ]

            # rotation taking [0,0,1] onto the rotation-axis direction a
            if np.allclose(a, [0, 0, 1]):
                axis_rotation = np.eye(3)
            elif np.allclose(a, [0, 0, -1]):
                axis_rotation = np.eye(3)
                axis_rotation[1][1] = -1
                axis_rotation[2][2] = -1
            else:
                vx = np.array([[0, 0, -a[0]], [0, 0, -a[1]], [a[0], a[1], 0]])
                axis_rotation = np.eye(3) + vx + vx.dot(vx) * 1 / (1 + a[2])

            rotation_matrix = axis_rotation.T

            vox0 = None
            for start, end in combos:
                # 1D vectors so each component is a scalar (float() on 1-element
                # arrays is deprecated in NumPy)
                s = rotation_matrix.dot(np.asarray(start, dtype=float))
                e = rotation_matrix.dot(np.asarray(end, dtype=float))

                x_data = float(s[0]) + ro[0], float(e[0]) + ro[0]
                y_data = float(s[1]) + ro[1], float(e[1]) + ro[1]
                z_data = float(s[2]) + ro[2], float(e[2]) + ro[2]

                self.ax.plot3D(x_data, y_data, z_data, color="r", ls='dotted', alpha=1)

                if vox0 is None:
                    # the start of the first edge is voxel 0
                    vox0 = (x_data[0], y_data[0], z_data[0])
        else:
            # draw square
            x = [self.image_geometry.get_min_x(), self.image_geometry.get_max_x()]
            y = [self.image_geometry.get_min_y(), self.image_geometry.get_max_y()]
            vertex = np.array([(x[0], y[0], 0), (x[0], y[1], 0), (x[1], y[1], 0), (x[1], y[0], 0)]) + ro
            self.ax.plot3D(*zip(*vertex, vertex[0]), color='r', ls='dotted', alpha=1)

            vox0 = vertex[0]
            rotation_matrix = np.eye(3)

        # rotation direction: a 324-degree arc around the rotation axis
        points = 36
        xs = [None] * points
        ys = [None] * points
        zs = [None] * points

        for i in range(points):
            theta = i * (np.pi * 1.8) / points
            point_i = np.array([np.sin(theta), -np.cos(theta), 0])
            point_rot = -self.scale * 0.5 * rotation_matrix.dot(point_i)

            xs[i] = float(point_rot[0] + ro[0])
            ys[i] = float(point_rot[1] + ro[1])
            zs[i] = float(point_rot[2] + ro[2])

        self.ax.plot3D(xs, ys, zs, color='r', ls="dashed", alpha=1)
        arrow4 = _Arrow3D(xs[-2:], ys[-2:], zs[-2:], mutation_scale=20, lw=1, arrowstyle="-|>", color="r")
        self.ax.add_artist(arrow4)

        handles = [
            mlines.Line2D([], [], color='r', linestyle='solid', markersize=12, label='rotation axis direction'),
            mlines.Line2D([], [], color='r', linestyle='dotted', markersize=15, label='image geometry'),
            self.ax.scatter3D(*vox0, marker='x', alpha=1, color='r', lw=1, s=50, label='data origin (voxel 0)'),
            mlines.Line2D([], [], color='r', linestyle='dashed', markersize=12, label=r'rotation direction $\theta$'),
        ]

        for handle in handles:
            self.handles.append(handle)
            self.labels.append(handle.get_label())

    def display_source(self):
        '''Draw the cone-beam source and rays to the detector corners and centre.'''
        so = self.acquisition_geometry.config.system.source.position
        det = self.detector_vertex()

        for corner in det:
            self.ax.plot3D(*zip(so, corner), color='#D4BD72', ls="dashed", alpha=0.4)

        self.ax.plot3D(*zip(so, self.acquisition_geometry.config.system.detector.position), color='#D4BD72', ls="solid", alpha=1)
        h0 = self.ax.scatter3D(*so, marker='*', alpha=1, color='#D4BD72', lw=1, label='source position', s=100)

        self.handles.append(h0)
        self.labels.append(h0.get_label())

    def display_ray(self):
        '''Draw parallel rays arriving at the detector corners and centre.'''
        det = self.detector_vertex()
        det.append(self.acquisition_geometry.config.system.detector.position)

        # ray length: twice the detector distance from the origin, or the detector
        # width when the detector sits at the origin
        dist = np.sqrt(np.sum(self.acquisition_geometry.config.system.detector.position ** 2)) * 2
        if dist < 0.01:
            dist = self.acquisition_geometry.config.panel.num_pixels[0] * self.acquisition_geometry.config.panel.pixel_size[0]

        ray_dir = self.acquisition_geometry.config.system.ray.direction
        rays = det - ray_dir * dist

        for i in range(len(rays)):
            h0 = self.ax.plot3D(*zip(rays[i], det[i]), color='#D4BD72', ls="dashed", alpha=0.4, label='ray direction')[0]
            arrow = _Arrow3D(*zip(rays[i], rays[i] + ray_dir * self.scale), mutation_scale=20, lw=1, arrowstyle="-|>", color="#D4BD72")
            self.ax.add_artist(arrow)

        self.handles.append(h0)
        self.labels.append(h0.get_label())
class show_geometry(show_base):
    '''
    Displays a schematic of the acquisition geometry.

    For 2D geometries the camera is fixed (elevation=90, azimuthal=0), so
    elevation and azimuthal cannot be changed.

    Parameters
    ----------
    acquisition_geometry: AcquisitionGeometry
        CIL acquisition geometry
    image_geometry: ImageGeometry, optional
        CIL image geometry
    elevation: float
        Camera elevation in degrees, 3D geometries only, default=20
    azimuthal: float
        Camera azimuthal in degrees, 3D geometries only, default=-35
    view_distance: float
        Camera view distance, default=10
    grid: boolean
        Show figure axis, default=False
    figsize: tuple (x, y)
        Set figure size (inches), default (10,10)
    fontsize: int
        Set fontsize, default 10

    Note
    ----
    The figure can be saved using the `save` method. i.e:

    >>> fig = show_geometry(geometry)
    >>> fig.save('/path/to/output.png')
    '''

    def __init__(self, acquisition_geometry, image_geometry=None, elevation=20, azimuthal=-35, view_distance=10, grid=False, figsize=(10,10), fontsize=10):
        # 2D geometries ignore the requested camera angles
        if AcquisitionType.DIM2 & acquisition_geometry.dimension:
            elevation, azimuthal = 90, 0

        self.display = _ShowGeometry(acquisition_geometry, image_geometry)

        draw_options = dict(elev=elevation, azim=azimuthal, view_distance=view_distance,
                            grid=grid, figsize=figsize, fontsize=fontsize)
        self.figure = self.display.draw(**draw_options)