Source code for cil.framework.labels

#  Copyright 2024 United Kingdom Research and Innovation
#  Copyright 2024 The University of Manchester
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from enum import Enum, Flag as _Flag, auto, unique
try:
    from enum import EnumType
except ImportError: # Python<3.11
    from enum import EnumMeta as EnumType


class _StrEnumMeta(EnumType):
    """
    Metaclass that makes `in` checks on a `StrEnum` class case-insensitive.

    Python<3.12 requires this in a metaclass (rather than directly in StrEnum).
    """
    def __contains__(self, item: str) -> bool:
        """
        Return True if `item` is a member name (in any letter case) or equals a member.

        Objects without a usable `upper()` method (e.g. ints or None) are never contained.
        """
        try:
            key = item.upper()
        except (AttributeError, TypeError):
            return False
        return key in self.__members__ or item in self.__members__.values()


@unique
class StrEnum(str, Enum, metaclass=_StrEnumMeta):
    """
    Case-insensitive StrEnum.

    Members are strings whose `auto()` value is the lower-cased member name.
    Construction from a string, `==`, `!=` and `in` all ignore letter case.
    """
    @classmethod
    def _missing_(cls, value: str):
        """
        Look up a value that is not an exact member value by its upper-cased name.

        Returns None for unknown names and for non-string values, which makes
        `Enum` raise its usual ValueError instead of leaking an AttributeError
        from calling `.upper()` on e.g. an int or None.
        """
        if not isinstance(value, str):
            return None
        return cls.__members__.get(value.upper(), None)

    def __eq__(self, value: str) -> bool:
        """Compare with a member or a string, treating strings as case-insensitive member names."""
        try:
            # swap a matching name (any case) for the member itself before comparing
            value = self.__class__[value.upper()]
        except (KeyError, ValueError, AttributeError):
            pass
        return super().__eq__(value)

    def __ne__(self, value: str) -> bool:
        """Exact inverse of `__eq__`."""
        # Without this override `str.__ne__` is inherited and compares
        # case-sensitively, so `member != "NAME"` would disagree with `member == "NAME"`.
        outcome = self.__eq__(value)
        return outcome if outcome is NotImplemented else not outcome

    def __hash__(self) -> int:
        """consistent hashing for dictionary keys"""
        return hash(self.value)

    # compatibility with Python>=3.11 `enum.StrEnum`
    __str__ = str.__str__
    __format__ = str.__format__

    @staticmethod
    def _generate_next_value_(name: str, start, count, last_values) -> str:
        """Value assigned by `auto()`: the member name in lower case."""
        return name.lower()


class Backend(StrEnum):
    """
    Backends available in CIL.

    Either the member or its name as a plain string may be supplied.

    Examples
    --------
    >>> FBP(data, backend=Backend.ASTRA)
    >>> FBP(data, backend="astra")
    """
    # Values are the lower-cased member names, i.e. the same strings
    # `auto()` generates for a `StrEnum`.
    ASTRA = "astra"
    TIGRE = "tigre"
    CIL = "cil"


class _DimensionBase:
    """Mixin adding engine-specific dimension-order helpers to a dimension label class."""

    @classmethod
    def _default_order(cls, engine: str) -> tuple:
        """Full dimension order for `engine`; subclasses must override this."""
        raise NotImplementedError

    @classmethod
    def get_order_for_engine(cls, engine: str, geometry=None) -> tuple:
        """
        Returns the order of dimensions for a specific engine and geometry.

        Parameters
        ----------
        engine: str
            Passed through to `_default_order`.
        geometry: ImageGeometry | AcquisitionGeometry
            If given, only labels found in `geometry.dimension_labels` are kept
            (in the engine's order). If unspecified, the default order is returned.
        """
        full_order = cls._default_order(engine)
        if geometry is not None:
            kept = []
            for label in full_order:
                if label in geometry.dimension_labels:
                    kept.append(label)
            return tuple(kept)
        return full_order

    @classmethod
    def check_order_for_engine(cls, engine: str, geometry) -> bool:
        """
        Returns True iff the geometry's dimension labels are already in the order the engine expects.

        Parameters
        ----------
        engine: str
            Passed through to `get_order_for_engine`.
        geometry: ImageGeometry | AcquisitionGeometry

        Raises
        ------
        ValueError if the order of dimensions is incorrect.
        """
        expected = cls.get_order_for_engine(engine, geometry)
        actual = tuple(geometry.dimension_labels)
        if expected != actual:
            raise ValueError(
                f"Expected dimension_label order {expected},"
                f" got {actual}."
                f" Try using `data.reorder('{engine}')` to permute for {engine}")
        return True


class ImageDimension(_DimensionBase, StrEnum):
    """
    Dimension labels that image data can carry.

    Examples
    --------
    >>> data.reorder([ImageDimension.HORIZONTAL_X, ImageDimension.VERTICAL])
    >>> data.reorder(["horizontal_x", "vertical"])
    """
    CHANNEL = auto()
    VERTICAL = auto()
    HORIZONTAL_X = auto()
    HORIZONTAL_Y = auto()

    @classmethod
    def _default_order(cls, engine: str) -> tuple:
        """Return the complete label order used by `engine` (converted with `Backend(engine)`)."""
        backend = Backend(engine)
        # every backend currently uses the same layout
        shared_layout = (cls.CHANNEL, cls.VERTICAL, cls.HORIZONTAL_Y, cls.HORIZONTAL_X)
        layouts = {
            Backend.ASTRA: shared_layout,
            Backend.TIGRE: shared_layout,
            Backend.CIL: shared_layout,
        }
        return layouts[backend]


class AcquisitionDimension(_DimensionBase, StrEnum):
    """
    Dimension labels that acquisition data can carry.

    Examples
    --------
    >>> data.reorder([AcquisitionDimension.CHANNEL,
                      AcquisitionDimension.ANGLE,
                      AcquisitionDimension.HORIZONTAL])
    >>> data.reorder(["channel", "angle", "horizontal"])
    """
    CHANNEL = auto()
    ANGLE = auto()
    VERTICAL = auto()
    HORIZONTAL = auto()

    @classmethod
    def _default_order(cls, engine: str) -> tuple:
        """Return the complete label order used by `engine` (converted with `Backend(engine)`)."""
        backend = Backend(engine)
        # ASTRA puts VERTICAL before ANGLE; TIGRE and CIL put ANGLE first
        vertical_first = (cls.CHANNEL, cls.VERTICAL, cls.ANGLE, cls.HORIZONTAL)
        angle_first = (cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL)
        layouts = {
            Backend.ASTRA: vertical_first,
            Backend.TIGRE: angle_first,
            Backend.CIL: angle_first,
        }
        return layouts[backend]


class FillType(StrEnum):
    """
    Fill types that can be requested when filling image data.

    Attributes
    ----------
    RANDOM:
        Fill with random values.
    RANDOM_INT:
        Fill with random integers.

    Examples
    --------
    >>> data.fill(FillType.RANDOM)
    >>> data.fill("random")
    """
    # Values are the lower-cased member names, i.e. the same strings
    # `auto()` generates for a `StrEnum`.
    RANDOM = "random"
    RANDOM_INT = "random_int"


class AngleUnit(StrEnum):
    """
    Units in which angles can be expressed.

    Examples
    --------
    >>> data.geometry.set_angles(angle_data, angle_units=AngleUnit.DEGREE)
    >>> data.geometry.set_angles(angle_data, angle_units="degree")
    """
    # Values are the lower-cased member names, i.e. the same strings
    # `auto()` generates for a `StrEnum`.
    DEGREE = "degree"
    RADIAN = "radian"


class _FlagMeta(EnumType):
    """
    Metaclass that makes string `in` checks on a `Flag` class case-insensitive.

    Python<3.12 requires this in a metaclass (rather than directly in Flag).
    """
    def __contains__(self, item) -> bool:
        """Match strings against member names ignoring case; defer to `EnumType` for anything else."""
        if isinstance(item, str):
            return item.upper() in self.__members__
        return super().__contains__(item)


@unique
class Flag(_Flag, metaclass=_FlagMeta):
    """
    Case-insensitive Flag.

    Members can be constructed from, and compared with, their names as strings
    in any letter case.
    """
    @classmethod
    def _missing_(cls, value):
        """Resolve strings by upper-cased member name (None if unknown); defer to `enum.Flag` otherwise."""
        if isinstance(value, str):
            return cls.__members__.get(value.upper(), None)
        return super()._missing_(value)

    def __eq__(self, value: str) -> bool:
        """
        Compare with a flag or a string naming one (case-insensitive).

        A string that does not name a member compares unequal instead of raising.
        """
        if isinstance(value, str):
            try:
                # go through the constructor so subclass aliases handled in `_missing_` apply
                value = self.__class__(value.upper())
            except ValueError:
                # an unknown name cannot be equal to any flag
                return False
        return super().__eq__(value)

    def __hash__(self) -> int:
        """consistent hashing for dictionary keys"""
        # Defining `__eq__` alone would set `__hash__` to None and make members unhashable.
        return hash(self.value)


class AcquisitionType(Flag):
    """
    Available acquisition types & dimensions.

    WARNING: It's best to use strings rather than integers to initialise.

    >>> AcquisitionType(3) == AcquisitionType(2 | 1) == AcquisitionType.CONE|PARALLEL != AcquisitionType('3D')

    Attributes
    ----------
    PARALLEL:
        Parallel beam.
    CONE:
        Cone beam.
    DIM2:
        2D acquisition.
    DIM3:
        3D acquisition.
    """
    PARALLEL = auto()
    CONE = auto()
    DIM2 = auto()
    DIM3 = auto()

    def validate(self):
        """
        Check if the geometry and dimension types are allowed.

        Returns
        -------
        self, so the call can be chained.

        Raises
        ------
        AssertionError if more than one dimension flag (DIM2, DIM3) or more
        than one geometry flag (PARALLEL, CONE) is set.
        """
        # Raised explicitly rather than with `assert` so the checks are not
        # stripped when Python runs with -O; the exception type is unchanged.
        if len(self.dimension) >= 2:
            raise AssertionError(f"{self} must be 2D xor 3D")
        if len(self.geometry) >= 2:
            raise AssertionError(f"{self} must be parallel xor cone beam")
        return self

    @property
    def dimension(self):
        """
        Returns the label for the dimension type
        """
        return self & (self.DIM2 | self.DIM3)

    @property
    def geometry(self):
        """
        Returns the label for the geometry type
        """
        return self & (self.PARALLEL | self.CONE)

    @classmethod
    def _missing_(cls, value):
        """2D/3D aliases: the strings '2D' and '3D' (any case) resolve to DIM2 and DIM3."""
        if isinstance(value, str):
            value = {'2D': 'DIM2', '3D': 'DIM3'}.get(value.upper(), value)
        return super()._missing_(value)

    def __str__(self) -> str:
        """2D/3D special handling: DIM2 and DIM3 print as '2D' and '3D', other flags by name."""
        if self == self.DIM2:
            return '2D'
        if self == self.DIM3:
            return '3D'
        # fall back to the default representation when the flag has no name
        return self.name or super().__str__()

    def __hash__(self) -> int:
        """consistent hashing for dictionary keys"""
        return hash(self.value)

    # compatibility with Python>=3.11 `enum.Flag`
    def __len__(self) -> int:
        """Number of individual flags set (count of 1 bits in the value)."""
        return bin(self.value).count('1')