"""
Primary module for ETSpy package.
Contains the TomoStack class and its methods.
"""
import copy
import inspect
import logging
from abc import ABC
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional, Union, cast
try:
from typing import Self # type: ignore
except ImportError: # pragma: no cover
# required to support Python 3.10 since Self was added in 3.11
from typing_extensions import Self
import astra # noqa: I001
import matplotlib as mpl
import numpy as np
import pylab as plt
from hyperspy._signals.signal1d import Signal1D
from hyperspy._signals.signal2d import Signal2D
from hyperspy.signal import BaseSignal, SpecialSlicersSignal
from matplotlib.figure import Figure
from scipy import ndimage
from skimage import transform
from traits.api import Undefined
from etspy import AlignmentMethod, AlignmentMethodType, FbpMethodType, ReconMethodType
from etspy import _format_choices as _fmt
from etspy import _get_literal_hint_values as _get_lit
from etspy import align, recon
if TYPE_CHECKING:
from types import FrameType
from hyperspy.axes import UniformDataAxis as Uda
from hyperspy.misc.utils import DictionaryTreeBrowser as Dtb
# Module-level logger. Its level is set explicitly to INFO, so records at INFO
# and above from this module are handed to handlers even when the root logger
# is configured at a higher level (handler levels still apply).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
[docs]
class MismatchedTiltError(ValueError):
    """
    Error raised when the number of tilts does not match the stack's tilt dimension.

    Group
    -----
    signals

    Order
    -----
    6
    """

    def __init__(self, num_tilts, tilt_dimension):
        """
        Create a MismatchedTiltError.

        Parameters
        ----------
        num_tilts
            Number of tilt values that were supplied.
        tilt_dimension
            Size of the tilt dimension of the data array.
        """
        message = (
            f"Number of tilts ({num_tilts}) does not match "
            f"the tilt dimension of the data array ({tilt_dimension})"
        )
        super().__init__(message)
# The following private classes are used to allow type hinting on inav and isig
class _TomoShiftSlicer(SpecialSlicersSignal):
    """Slicer whose results are annotated as :py:class:`TomoShifts` for type hints."""

    def __getitem__(self, slices, out=None) -> "TomoShifts":
        sliced = super().__getitem__(slices, out=out)
        return sliced
class _TomoTiltSlicer(SpecialSlicersSignal):
    """Slicer whose results are annotated as :py:class:`TomoTilts` for type hints."""

    def __getitem__(self, slices, out=None) -> "TomoTilts":
        sliced = super().__getitem__(slices, out=out)
        return sliced
class _TomoStackSlicer(SpecialSlicersSignal):
    """
    Slicer for :py:class:`TomoStack` that never collapses a signal axis.

    Navigation (``inav``) slicing is passed straight through. For signal
    (``isig``) slicing, any single-value index on a signal axis is widened into
    a one-element range, so the result still has two signal dimensions.
    """

    def __getitem__(self, slices, out=None) -> "TomoStack":
        # isNavigation is False for ``isig`` slicing, which must not reduce the
        # number of signal axes below two
        if not self.isNavigation:
            # a lone index (e.g. ``s.isig[5]`` or ``s.isig[2:8]``) arrives
            # un-wrapped; normalize to a tuple so it can be handled per axis
            if not isinstance(slices, tuple):
                slices = (slices,)
            slices = tuple(
                self._widen_single_index(axis_index, s)
                for axis_index, s in enumerate(slices)
            )
        return super().__getitem__(slices, out=out)

    def _widen_single_index(self, axis_index, s):
        """
        Convert a scalar index on signal axis ``axis_index`` into a 1-element slice.

        Integer indices are pixel-based and are extended by one pixel. Float
        indices are unit-based and are extended by the axis ``scale``. Any other
        index (slices, ``Ellipsis``, ...) is returned unchanged.
        """
        is_integer = isinstance(s, (int, np.integer))
        if not (is_integer or isinstance(s, (float, np.floating))):
            return s
        sig_axis = self.obj.axes_manager.signal_axes[axis_index]
        if is_integer:
            start = int(s)
            # for -1, ``start + 1`` would be 0 and give an empty slice, so use
            # an open end to select the last pixel instead
            stop = None if start == -1 else start + 1
        else:
            start = s
            stop = s + sig_axis.scale
        logger.warning(
            "Slicing a TomoStack signal axis with a single pixel "
            'is not supported. Returning a single pixel on the "%s" '
            "axis instead",
            sig_axis.name,
        )
        return slice(start, stop, None)
class _RecStackSlicer(SpecialSlicersSignal):
    """Slicer whose results are annotated as :py:class:`RecStack` for type hints."""

    def __getitem__(self, slices, out=None) -> "RecStack":
        sliced = super().__getitem__(slices, out=out)
        return sliced
[docs]
class TomoShifts(Signal1D):
    """
    Create a ``TomoShifts`` instance to store image shift values of a tomography stack.

    This class is used to enforce dimension sizes and provide customized slicing
    compared to a standard HyperSpy signal.

    Group
    -----
    signals

    Order
    -----
    4
    """

    _signal_type = "TomoShifts"
    _signal_dimension = 1

    def __init__(self, data, *args, **kwargs):
        """
        Create a TomoShifts signal instance.

        ``TomoShifts`` is a specialized sub-class of
        :py:class:`~hyperspy.api.signals.Signal1D` used to hold information
        about the `x` and `y` image shifts in a :py:class:`~etspy.base.TomoStack`.

        Parameters
        ----------
        data
            The signal data. Can be provided as either a HyperSpy
            :py:class:`hyperspy.api.signals.Signal1D` or a Numpy array
            of the shape `(ntilts, 2)`.
        args
            Additional non-keyword arguments passed to the
            :py:class:`~hyperspy.api.signals.Signal1D` constructor.
        kwargs
            Additional keyword arguments passed to the
            :py:class:`hyperspy.api.signals.Signal1D` constructor.

        Raises
        ------
        ValueError
            If the data provided does not result in a Signal with signal shape (2,).
        """
        super().__init__(data, *args, **kwargs)
        if self.axes_manager.signal_shape != (2,):
            msg = (
                f"Shift values must have a signal shape of (2,), "
                f"but was {self.axes_manager.signal_shape}"
            )
            raise ValueError(msg)
        # custom slicers so that slicing results are typed as TomoShifts
        self.inav = _TomoShiftSlicer(self, isNavigation=True)
        self.isig = _TomoShiftSlicer(self, isNavigation=False)

    def _slicer(self, slices, isNavigation=None, out=None):  # noqa: N803
        """
        Slice the signal, but never along its signal axis.

        Navigation slicing is delegated to HyperSpy. Signal (``isig``) slicing
        would break the required (2,) signal shape, so ``self`` is returned
        unchanged instead.
        """
        if isNavigation:
            return super()._slicer(slices, isNavigation, out)

        # '_additional_slicing_targets' is among the caller's names only when
        # this call comes from HyperSpy's slicing of a parent stack (i.e.
        # ``s.isig[]``) rather than ``s.shifts.isig[]``. The former happens so
        # frequently that we don't warn for it.
        # inspect.currentframe() and f_back may be None on some interpreters.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        caller_names = caller.f_code.co_names if caller is not None else ()
        # release frame references promptly to avoid reference cycles
        del frame, caller
        if "_additional_slicing_targets" not in caller_names:
            logger.warning(
                "TomoShifts does not support 'isig' slicing, as signal shape must "
                "be (2,). Signal was returned with its original "
                "shape: %s",
                self,
            )
        return self
[docs]
class TomoTilts(Signal1D):
    """
    Create a ``TomoTilts`` instance, used to hold the tilt values of a tomography stack.

    This class is used to enforce dimension sizes and provide customized slicing
    compared to a standard HyperSpy signal.

    Group
    -----
    signals

    Order
    -----
    5
    """

    _signal_type = "TomoTilts"
    _signal_dimension = 1

    def __init__(self, data, *args, **kwargs):
        """
        Create a TomoTilts signal instance.

        ``TomoTilts`` is a specialized sub-class of
        :py:class:`~hyperspy.api.signals.Signal1D` used to hold information
        about the tilt values of each projection in a :py:class:`~etspy.base.TomoStack`.

        Parameters
        ----------
        data
            The signal data. Can be provided as either a HyperSpy
            :py:class:`hyperspy.api.signals.Signal1D` or a Numpy array
            of the shape `(ntilts, 1)`.
        args
            Additional non-keyword arguments passed to the
            :py:class:`~hyperspy.api.signals.Signal1D` constructor.
        kwargs
            Additional keyword arguments passed to the
            :py:class:`hyperspy.api.signals.Signal1D` constructor.

        Raises
        ------
        ValueError
            If the data provided does not result in a Signal with signal shape (1,).
        """
        super().__init__(data, *args, **kwargs)
        if self.axes_manager.signal_shape != (1,):
            msg = (
                f"Tilt values must have a signal shape of (1,), "
                f"but was {self.axes_manager.signal_shape}"
            )
            raise ValueError(msg)
        # custom slicers so that slicing results are typed as TomoTilts
        self.inav = _TomoTiltSlicer(self, isNavigation=True)
        self.isig = _TomoTiltSlicer(self, isNavigation=False)

    def _slicer(self, slices, isNavigation=None, out=None):  # noqa: N803
        """
        Slice the signal, but never along its signal axis.

        Navigation slicing is delegated to HyperSpy. Signal (``isig``) slicing
        would break the required (1,) signal shape, so ``self`` is returned
        unchanged instead.
        """
        if isNavigation:
            return super()._slicer(slices, isNavigation, out)

        # '_additional_slicing_targets' is among the caller's names only when
        # this call comes from HyperSpy's slicing of a parent stack (i.e.
        # ``s.isig[]``) rather than ``s.tilts.isig[]``. The former happens so
        # frequently that we don't warn for it.
        # inspect.currentframe() and f_back may be None on some interpreters.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        caller_names = caller.f_code.co_names if caller is not None else ()
        # release frame references promptly to avoid reference cycles
        del frame, caller
        if "_additional_slicing_targets" not in caller_names:
            logger.warning(
                "TomoTilts does not support 'isig' slicing, as signal shape must "
                "be (1,). Signal was returned with its original shape: %s",
                self,
            )
        return self
[docs]
class CommonStack(Signal2D, ABC):
    """
    An abstract base class for tomography data.

    .. abstract::

        This class is intended to be subclassed (*e.g.* by
        :py:class:`~etspy.base.TomoStack` and :py:class:`~etspy.base.RecStack`) and
        should not be instantiated directly. Doing so will raise a
        :py:exc:`NotImplementedError`.

    All arguments (other than ``tilts`` and ``shifts``) are passed to the
    :py:class:`~hyperspy.api.signals.Signal2D`
    constructor and should be used as documented for that method.

    Group
    -----
    signals

    Order
    -----
    3
    """

    _signal_type = "CommonStack"
    _signal_dimension = 2

    def __init__(
        self,
        data: Union[np.ndarray, Signal2D],
        *args,
        **kwargs,
    ):
        """
        Create an ETSpy signal instance.

        Parameters
        ----------
        data
            The signal data. Can be provided as either a HyperSpy
            :py:class:`hyperspy.api.signals.Signal2D` or a Numpy array
            of the shape `(tilt, y, x)`.
        args
            Additional non-keyword arguments passed to the
            :py:class:`~hyperspy.api.signals.Signal2D` constructor
        kwargs
            Additional keyword arguments passed to the
            :py:class:`hyperspy.api.signals.Signal2D` constructor

        Raises
        ------
        NotImplementedError
            :py:class:`~etspy.base.CommonStack` is not intended to be used directly.
            One of its sub-classes (:py:class:`~etspy.base.TomoStack` or
            :py:class:`~etspy.base.RecStack`) should be used instead.
        """
        if type(self) is CommonStack:
            msg = (
                "CommonStack should not be instantiated directly. Use one of its "
                "sub-classes instead (TomoStack or RecStack)"
            )
            raise NotImplementedError(msg)
        super().__init__(data, *args, **kwargs)

    def plot(self, navigator: str = "slider", *args, **kwargs):
        """
        Override of plot function to set default HyperSpy navigator to 'slider'.

        Parameters
        ----------
        navigator
            Navigator passed to :py:meth:`hyperspy.api.signals.Signal2D.plot`.
            Defaults to ``"slider"``.
        args
            Additional positional arguments, forwarded after ``navigator``.
        kwargs
            Additional keyword arguments passed to
            :py:meth:`hyperspy.api.signals.Signal2D.plot`.
        """
        # navigator is passed positionally so that any extra positional args
        # follow it. Passing it as a keyword would make args[0] collide with
        # ``navigator`` and raise a TypeError.
        super().plot(navigator, *args, **kwargs)

    def change_data_type(self, dtype: Union[str, np.dtype]):
        """
        Change the data type of a stack.

        Use instead of the inherited change_dtype function of Hyperspy which results in
        conversion of the Stack to a Signal2D.

        Parameters
        ----------
        dtype
            A string that represents a NumPy data type, or a specific data type
        """
        self.data = self.data.astype(dtype)

    def invert(self) -> Self:
        """
        Create a copy of a Stack with inverted contrast levels.

        Statistics are computed independently for each image (over the last two
        axes). Images with a constant intensity (zero range) produce
        non-finite values.

        Returns
        -------
        inverted : Self
            Copy of the input stack with contrast inverted

        Examples
        --------
        >>> import etspy.datasets as ds
        >>> stack = ds.get_needle_data()
        >>> s_inverted = stack.invert()
        """
        # per-image statistics, reduced over x then y, with two trailing
        # length-1 axes appended so they broadcast against each (y, x) image
        maxvals = self.data.max(axis=-1).max(axis=-1)[..., np.newaxis, np.newaxis]
        minvals = self.data.min(axis=-1).min(axis=-1)[..., np.newaxis, np.newaxis]
        ranges = maxvals - minvals
        inverted = self.deepcopy()
        meanvals = inverted.data.mean(axis=-1).mean(axis=-1)[..., np.newaxis, np.newaxis]
        inverted.data = inverted.data - meanvals
        inverted.data = (inverted.data - minvals) / ranges
        inverted.data = inverted.data - 1
        # absolute value (equivalent to sqrt(x**2) without overflow risk)
        inverted.data = np.abs(inverted.data)
        inverted.data = (inverted.data * ranges) + minvals
        return inverted

    def normalize(self, width: int = 3) -> Self:
        """
        Create a copy of a stack with normalized contrast levels.

        Each image is shifted so its minimum is zero, then divided by its mean
        plus ``width`` times its standard deviation.

        Parameters
        ----------
        width
            Number of standard deviations from the mean to set
            as maximum intensity level.

        Returns
        -------
        normalized : Self
            Copy of the input stack with intensities normalized

        Examples
        --------
        >>> import etspy.datasets as ds
        >>> stack = ds.get_needle_data()
        >>> s_normalized = stack.normalize()
        """
        normalized = self.deepcopy()
        minvals = normalized.data.min(axis=-1).min(axis=-1)[..., np.newaxis, np.newaxis]
        normalized.data = normalized.data - minvals
        meanvals = normalized.data.mean(axis=-1).mean(axis=-1)[
            ..., np.newaxis, np.newaxis
        ]
        # standard deviation over all pixels of each image. Chaining
        # std(2).std(1) would instead give the spread of the row-wise
        # standard deviations, which is not the image standard deviation.
        stdvals = normalized.data.std(axis=(-2, -1), keepdims=True)
        normalized.data = normalized.data / (meanvals + width * stdvals)
        return normalized

    def save_raw(self, filename: Optional[Union[str, Path]] = None) -> Path:
        """
        Save Stack data as a .raw/.rpl file pair.

        Parameters
        ----------
        filename
            Name of file to receive data. If not specified, the metadata title
            will be used. Data dimensions (e.g. ``_77x256x256``) and data type
            (e.g. ``_float32``) will be appended, and the extension set to
            ``.rpl``.

        Returns
        -------
        filename : pathlib.Path
            The path to the file that was saved
        """
        if filename is None:
            filename = Path(str(cast("Dtb", self.metadata.General).title))
        elif isinstance(filename, str):
            filename = Path(filename)
        # join however many dimensions the data has (e.g. "77x256x256")
        dims = "x".join(str(size) for size in self.data.shape)
        filename = filename.parent / (
            f"{filename.stem}_{dims}_{self.data.dtype.name}.rpl"
        )
        self.save(filename)
        return filename

    def stats(self):
        """Print some basic statistics about Stack data."""
        print(f"Mean: {self.data.mean():.1f}")  # noqa: T201
        print(f"Std: {self.data.std():.2f}")  # noqa: T201
        print(f"Max: {self.data.max():.1f}")  # noqa: T201
        print(f"Min: {self.data.min():.1f}\n")  # noqa: T201

    def trans_stack(
        self,
        xshift: float = 0.0,
        yshift: float = 0.0,
        angle: float = 0.0,
        interpolation: Literal["linear", "cubic", "nearest", "none"] = "linear",
    ) -> Self:
        """
        Create a copy of a Stack, transformed using the ``skimage`` Affine transform.

        Parameters
        ----------
        xshift
            Number of pixels by which to shift in the X dimension
        yshift
            Number of pixels by which to shift the stack in the Y dimension
        angle
            Angle in degrees by which to rotate the stack about the X-Y plane
        interpolation
            Mode of interpolation to employ. Must be either ``'linear'``,
            ``'cubic'``, ``'nearest'`` or ``'none'`` (case-insensitive). Note
            that ``'nearest'`` and ``'none'`` are equivalent. Default is
            ``'linear'``.

        Returns
        -------
        out : Self
            Transformed copy of the input stack

        Raises
        ------
        ValueError
            If ``interpolation`` is not one of the supported methods.

        Examples
        --------
        >>> import etspy.datasets as ds
        >>> stack = ds.get_needle_data()
        >>> xshift = 10.0
        >>> yshift = 3.5
        >>> angle = -15.2
        >>> transformed = stack.trans_stack(xshift, yshift, angle)
        >>> transformed
        <TomoStack, title: , dimensions: (77|256, 256)>
        """
        # spline order passed to skimage.transform.warp for each method;
        # validated up front so a bad value fails before any copying is done
        interpolation_orders = {"nearest": 0, "none": 0, "linear": 1, "cubic": 3}
        interpolation_order = interpolation_orders.get(interpolation.lower())
        if interpolation_order is None:
            msg = (
                f'Invalid interpolation method "{interpolation}". Must be one of '
                f"{_fmt(_get_lit(self.trans_stack, 'interpolation'))}."
            )
            raise ValueError(msg)

        transformed = self.deepcopy()
        # interpolated (non-nearest) values would be truncated when written
        # back into an integer array, so work in float for those orders
        if interpolation_order > 0 and np.issubdtype(
            transformed.data.dtype,
            np.integer,
        ):
            transformed.data = transformed.data.astype("float32")
            logger.debug("Data converted to float prior to transformation")

        theta = np.pi * angle / 180.0
        center_y, center_x = np.array(
            np.array(transformed.data.shape[1:]) / 2,
            dtype=np.float32,
        )

        # rotation about the image center: translate center to origin,
        # rotate, translate back
        rot_mat = np.array(
            [
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1],
            ],
        )
        trans_mat = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
        rev_mat = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
        rotation_mat = np.dot(np.dot(trans_mat, rot_mat), rev_mat)

        shift = np.array(
            [[1, 0, np.float32(xshift)], [0, 1, np.float32(-yshift)], [0, 0, 1]],
        )
        full_transform = np.dot(shift, rotation_mat)
        tform = transform.AffineTransform(full_transform)

        for i in range(transformed.data.shape[0]):
            # preserve_range keeps warp from rescaling intensities according
            # to scikit-image's float conventions; for float data it is a no-op
            transformed.data[i, :, :] = transform.warp(
                transformed.data[i, :, :],
                inverse_map=tform.inverse,
                order=interpolation_order,
                preserve_range=True,
            )

        # record the cumulative transformation in the Tomography metadata
        trans_tomo_meta = cast("Dtb", transformed.metadata.Tomography)
        self_tomo_meta = cast("Dtb", self.metadata.Tomography)
        trans_tomo_meta.xshift = cast("float", self_tomo_meta.xshift) + xshift
        trans_tomo_meta.yshift = cast("float", self_tomo_meta.yshift) + yshift
        trans_tomo_meta.tiltaxis = cast("float", self_tomo_meta.tiltaxis) + angle
        return transformed
[docs]
class TomoStack(CommonStack):
"""
Create a TomoStack instance, used to represent tomographic tilt series data.
All arguments are passed to the :py:class:`~hyperspy.api.signals.Signal2D`
constructor and should be used as documented for that method.
Group
-----
signals
Order
-----
1
"""
_signal_type = "TomoStack"
_signal_dimension = 2
def _create_tomostack_from_signal(self, data, tilts, *args, **kwargs):
    """
    Create stack from HyperSpy signal (helper method for __init__).

    Axes, metadata, and original metadata are taken from ``data`` unless they
    are supplied explicitly through ``kwargs`` (``axes``, ``metadata``,
    ``original_metadata``), in which case the supplied values are used.
    Default ``Tomography`` metadata is added when none is present. Neither the
    input signal's metadata nor a ``metadata`` mapping passed in ``kwargs`` is
    modified.

    Raises
    ------
    MismatchedTiltError
        If ``tilts`` is given and its length differs from the number of
        projections in ``data``.
    """
    if data.axes_manager.navigation_dimension == 1:
        ntilts = data.axes_manager[0].size
    else:
        # multi-dimensional navigation: the tilt axis is looked up by name
        ntilts = data.axes_manager["Projections"].size
    if (tilts is not None) and (ntilts != tilts.data.shape[0]):
        raise MismatchedTiltError(tilts.data.shape[0], ntilts)

    default_tomo_metadata = {
        "cropped": False,
        "tiltaxis": 0,
        "xshift": 0,
        "yshift": 0,
    }

    # metadata supplied in kwargs overrides the signal's own metadata. It is
    # shallow-copied so that adding default Tomography metadata below does not
    # alter the caller's mapping. kwargs.pop also prevents passing it twice.
    if "metadata" in kwargs:
        metadata_dict = copy.copy(kwargs.pop("metadata"))
    else:
        # NOTE(review): relies on as_dictionary() returning a new dict, so the
        # input signal's metadata is not modified (HyperSpy behavior)
        metadata_dict = data.metadata.as_dictionary()
    if "Tomography" not in metadata_dict:
        metadata_dict["Tomography"] = default_tomo_metadata

    # same precedence for original metadata
    if "original_metadata" in kwargs:
        original_metadata_dict = kwargs.pop("original_metadata")
    else:
        original_metadata_dict = data.original_metadata.as_dictionary()

    # and for axes (sorted by their "axis-N" dictionary keys)
    if "axes" in kwargs:
        axes_list = kwargs.pop("axes")
    else:
        axes_list = [
            x for _, x in sorted(data.axes_manager.as_dictionary().items())
        ]

    super().__init__(
        data,
        *args,
        axes=axes_list,
        metadata=metadata_dict,
        original_metadata=original_metadata_dict,
        **kwargs,
    )
def _fix_projection_axis(self, axis_number: int):
    """
    Apply default projection labelling to axis ``axis_number``.

    An unnamed axis, or one named ``"z"``, is renamed ``"Projections"``. If
    the axis has no units, they become ``"degrees"`` and its ``scale``
    attribute is deleted.
    """
    proj_axis = cast("Uda", self.axes_manager[axis_number])
    if proj_axis.name in (Undefined, "z"):
        proj_axis.name = "Projections"
    units_missing = proj_axis.units == Undefined
    if units_missing:
        proj_axis.units = "degrees"
        # NOTE(review): deleting the trait presumably resets scale to its
        # default value; confirm against traits/HyperSpy behavior
        del proj_axis.scale
def _fix_frames_axis(self, axis_number: int):
    """
    Apply default multi-frame labelling to axis ``axis_number``.

    An unnamed axis, or one named ``"z"``, is renamed ``"Frames"``. If the
    axis has no units, they become ``"images"`` and its ``scale`` attribute
    is deleted.
    """
    frame_axis = cast("Uda", self.axes_manager[axis_number])
    if frame_axis.name in (Undefined, "z"):
        frame_axis.name = "Frames"
    units_missing = frame_axis.units == Undefined
    if units_missing:
        frame_axis.units = "images"
        # NOTE(review): deleting the trait presumably resets scale to its
        # default value; confirm against traits/HyperSpy behavior
        del frame_axis.scale
def _create_tomostack_from_ndarray(self, data, tilts, *args, **kwargs):
    """
    Create stack from Numpy array (helper method for __init__).

    Any ``Tomography`` entry in ``kwargs["metadata"]`` is used as the
    tomography metadata; otherwise defaults are used. Navigation axes are
    labelled according to the number of navigation dimensions (0, 1, or 2),
    and the signal axes are named ``"x"``/``"y"`` with ``"pixels"`` as default
    units.

    Raises
    ------
    MismatchedTiltError
        If ``tilts`` is given and its length differs from ``data.shape[0]``.
    ValueError
        If the resulting signal has more than two navigation dimensions.
    """
    ntilts = data.shape[0]
    if tilts is not None and ntilts != tilts.data.shape[0]:
        raise MismatchedTiltError(tilts.data.shape[0], ntilts)

    has_supplied_tomo = "metadata" in kwargs and "Tomography" in kwargs["metadata"]
    if has_supplied_tomo:
        tomo_metadata = kwargs["metadata"]["Tomography"]
    else:
        tomo_metadata = {
            "cropped": False,
            "tiltaxis": 0,
            "xshift": 0,
            "yshift": 0,
        }

    super().__init__(data, *args, **kwargs)
    self.metadata.add_node("Tomography")
    cast("Dtb", self.metadata.Tomography).add_dictionary(tomo_metadata)

    nav_dim = self.axes_manager.navigation_dimension
    if nav_dim == 1:
        # normal single-frame tilt series
        self._fix_projection_axis(0)
    elif nav_dim == 2:  # noqa: PLR2004
        # multiframe: projections on the slow axis, frames on the fast one
        self._fix_projection_axis(1)
        self._fix_frames_axis(0)
    elif nav_dim != 0:
        # 0 navigation dimensions (a single projection) is allowed as-is
        msg = (
            "Invalid number of navigation dimensions for a TomoStack ("
            f"{nav_dim}). Must be either 0, 1, or 2."
        )
        raise ValueError(msg)

    signal_axes = cast("tuple[Uda, Uda]", self.axes_manager.signal_axes)
    for index, axis_name in enumerate(("x", "y")):
        sig_axis = signal_axes[index]
        sig_axis.name = axis_name
        if sig_axis.units == Undefined:
            sig_axis.units = "pixels"
[docs]
def __init__(
    self,
    data: Union[np.ndarray, Signal2D],
    tilts: Optional[Union[TomoTilts, np.ndarray]] = None,
    shifts: Optional[Union[TomoShifts, np.ndarray]] = None,
    *args,
    **kwargs,
):
    """
    Create a TomoStack signal instance.

    Parameters
    ----------
    data
        The signal data. Can be provided as either a HyperSpy
        :py:class:`hyperspy.api.signals.Signal2D` or a Numpy array
        of the shape `(tilt, y, x)`.
    tilts
        A :py:class:`~etspy.base.TomoTilts` (or :py:class:`~numpy.ndarray`)
        containing the tilt value (in degrees) for each projection in the
        stack. The navigation dimension should match the navigation dimension
        of the stack (one value per tilt image). If ``None``, the tilts will be
        initialized to zero.
    shifts
        A :py:class:`~etspy.base.TomoShifts` or :py:class:`~numpy.ndarray`
        containing the x/y image shift value (in pixels) for each projection in
        the stack. A signal should have the same navigation dimension as the stack.
        A Numpy array should have shape `(nav_size, 2)`. If ``None``, the ``shifts``
        will be initialized to zero-valued signal. If shifts are supplied as an
        :py:class:`~numpy.ndarray`, the Y-shifts (perpendicular to the tilt axis)
        should be in the ``shifts[:, 0]`` position and X-shifts (parallel to the
        tilt axis) in ``shifts[:, 1]``.
    args
        Additional non-keyword arguments passed to the
        :py:class:`~hyperspy.api.signals.Signal2D` constructor
    kwargs
        Additional keyword arguments passed to the
        :py:class:`hyperspy.api.signals.Signal2D` constructor

    Raises
    ------
    ValueError
        If the ``tilts`` or ``shifts`` signals provided do not have the correct
        dimensions (:py:exc:`~etspy.base.MismatchedTiltError`, a ValueError,
        if the number of tilts does not match the number of projections).
    TypeError
        If ``data`` is neither a :py:class:`~hyperspy.api.signals.Signal2D`
        nor a :py:class:`~numpy.ndarray`.
    """
    # copy axes and metadata if input was already a signal
    if isinstance(data, Signal2D):
        self._create_tomostack_from_signal(data, tilts, *args, **kwargs)
    elif isinstance(data, np.ndarray):
        self._create_tomostack_from_ndarray(data, tilts, *args, **kwargs)
    else:
        # otherwise nothing above would initialize the signal, and the
        # shifts/tilts assignment below would fail with an unrelated error
        msg = (
            "TomoStack data must be a HyperSpy Signal2D or a Numpy array, "
            f"not {type(data).__name__}"
        )
        raise TypeError(msg)
    self.shifts = shifts
    self.tilts = tilts

    if self.axes_manager.navigation_axes:
        # normal single-frame mode:
        if (
            # only set axis information if it is undefined
            self.axes_manager.navigation_dimension == 1
            and (
                self.axes_manager.navigation_axes[0].units == Undefined
                or self.axes_manager.navigation_axes[0].units == ""
            )
        ):
            nav_ax_0 = cast("Uda", self.axes_manager.navigation_axes[0])
            nav_ax_0.name = "Projections"
            nav_ax_0.units = "degrees"
        # multiframe (special SerialEM case)
        elif self.axes_manager.navigation_dimension == 2:  # noqa: PLR2004
            nav_ax_0 = cast("Uda", self.axes_manager.navigation_axes[0])
            nav_ax_1 = cast("Uda", self.axes_manager.navigation_axes[1])
            nav_ax_0.name = (
                "Frames" if nav_ax_0.name == Undefined else nav_ax_0.name
            )
            nav_ax_0.units = (
                "images" if nav_ax_0.units == Undefined else nav_ax_0.units
            )
            nav_ax_1.name = (
                "Projections" if nav_ax_1.name == Undefined else nav_ax_1.name
            )
            nav_ax_1.units = (
                "degrees" if nav_ax_1.units == Undefined else nav_ax_1.units
            )

    if self.axes_manager.signal_axes:
        cast("Uda", self.axes_manager.signal_axes[0]).name = "x"
        cast("Uda", self.axes_manager.signal_axes[1]).name = "y"

    # ensure that shifts and tilts will be sliced when the Signal is sliced
    self._additional_slicing_targets = [
        "metadata.Signal.Noise_properties.variance",
        "_shifts",
        "_tilts",
    ]
    self.inav = _TomoStackSlicer(self, isNavigation=True)
    self.isig = _TomoStackSlicer(self, isNavigation=False)
def _check_array_shape(
    self,
    array: Union[TomoShifts, TomoTilts, np.ndarray],
    mode: Literal["shifts", "tilts"],
) -> Union[TomoShifts, TomoTilts, np.ndarray]:
    """
    Validate (and possibly reshape) shift or tilt values for this stack.

    Used by the shifts and tilts setters as a sanity check. Values must have
    shape ``(*reversed navigation shape, n)`` with ``n`` equal to 2 for shifts
    and 1 for tilts. As a convenience, a Numpy tilt array lacking the final
    length-1 axis is reshaped to add it.

    Raises
    ------
    ValueError
        If the values do not have the expected shape.
    """
    # Numpy arrays use the reverse of HyperSpy's navigation shape ordering
    nav_shape_np = (*self.axes_manager.navigation_shape[::-1],)
    values_per_nav = 2 if mode == "shifts" else 1
    actual_shape = (array.data if isinstance(array, BaseSignal) else array).shape

    flat_tilt_array = (
        isinstance(array, np.ndarray)
        and mode == "tilts"
        and array.shape == nav_shape_np
    )
    if flat_tilt_array:
        return array.reshape((*nav_shape_np, 1))

    expected_shape = (*nav_shape_np, values_per_nav)
    if actual_shape != expected_shape:
        msg = (
            f"Shape of {mode} array must be "
            f"{expected_shape} to match "
            f"the navigation size of the stack (was {actual_shape})"
        )
        raise ValueError(msg)
    return array
[docs]
def shift_and_tilt_setter(
    self,
    mode: Literal["shifts", "tilts"],
    value: Optional[Union[TomoShifts, TomoTilts, np.ndarray]],
) -> Union[TomoShifts, TomoTilts]:
    """
    Build the signal to store as either ``self._tilts`` or ``self._shifts``.

    This method is split out to reduce duplication of code between the
    :py:attr:`~etspy.base.TomoStack.tilts` and
    :py:attr:`~etspy.base.TomoStack.shifts`
    property setter functions, since they have significant
    overlap.

    Parameters
    ----------
    mode
        Whether to work on the ``_shifts`` or the ``_tilts`` of the stack
    value
        The values to set, as either an array, or
        :py:class:`~etspy.base.TomoShifts`, or :py:class:`~etspy.base.TomoTilts`.
        If ``None``, the values will be initialized to an array of zeros of the
        appropriate shape. Metadata and axis attributes of a supplied signal are
        retained, but its navigation axes are always replaced by this stack's.

    Returns
    -------
    target : :py:class:`~etspy.base.TomoShifts` or :py:class:`~etspy.base.TomoTilts`
        The signal that should be set as either the shifts or tilts property.

    Raises
    ------
    ValueError
        If the ``value`` is not the correct shape for either the ``shifts`` or
        ``tilts`` property
    """
    signal_size = 2 if mode == "shifts" else 1
    _cls = TomoShifts if mode == "shifts" else TomoTilts
    # "Image shift values" / "Image tilt values"
    default_title = f"Image {mode[:-1]} values"

    if value is None:
        # shifts should be shape (self.nav_size | 2), tilts (self.nav_size | 1)
        # numpy arrays should be the inverse of the navigation shape, so if
        # self.axes_manager.navigation shape is (N, M), then the data shape
        # provided to TomoShifts or TomoTilts should be (M, N)
        target = _cls(
            data=np.zeros((*self.axes_manager.navigation_shape[::-1], signal_size)),
        )
        target.metadata.set_item("General.title", default_title)
    elif isinstance(value, np.ndarray):
        target = _cls(data=self._check_array_shape(value, mode))
        target.metadata.set_item("General.title", default_title)
    else:
        # value is already a Signal, so test the dimensions
        value = cast(
            "Union[TomoTilts, TomoShifts]",
            self._check_array_shape(value, mode),
        )
        # Using the TomoTilts/TomoShifts constructor strips metadata and axis
        # info, so copy it back:
        target = cast("Union[TomoTilts, TomoShifts]", _cls(data=value))
        target.metadata.add_dictionary(value.metadata.as_dictionary())
        target.axes_manager.update_axes_attributes_from(
            (*value.axes_manager.navigation_axes, *value.axes_manager.signal_axes),
            ("name", "offset", "scale", "units"),
        )
        # only fill in a title if the supplied signal had none
        if target.metadata.get_item("General.title") == "":
            target.metadata.set_item("General.title", default_title)

    # set the navigation axes to the same as the stack's:
    target.axes_manager.update_axes_attributes_from(
        self.axes_manager.navigation_axes,
        ("name", "offset", "scale", "units"),
    )
    if target.axes_manager.signal_axes:
        tilt_sig_ax = target.axes_manager.signal_axes[0]  # pyright: ignore[reportGeneralTypeIssues]
        tilt_sig_ax.name = (
            "Shift values (x/y)" if mode == "shifts" else "Tilt values"
        )
        tilt_sig_ax.units = "pixels" if mode == "shifts" else "degrees"
    return target
@property
def shifts(self) -> TomoShifts:
    """
    The stack's image shift values (in pixels).

    A :py:class:`~etspy.base.TomoShifts` signal containing the
    x/y image shift value (in pixels) for each projection in the stack.
    Should have the same navigation dimension as the stack.
    """
    return self._shifts

@shifts.setter
def shifts(self, new_shifts: Optional[Union[TomoShifts, np.ndarray]]):
    # validate/convert through the shared helper before storing
    validated = self.shift_and_tilt_setter("shifts", new_shifts)
    self._shifts = cast("TomoShifts", validated)

@shifts.deleter
def shifts(self):
    # "deleting" resets the shifts to zeros with the current shape and dtype
    zeroed = np.zeros_like(self.shifts.data)
    validated = self.shift_and_tilt_setter("shifts", zeroed)
    self._shifts = cast("TomoShifts", validated)
@property
def tilts(self) -> TomoTilts:
    """
    The stack's tilt values (in degrees).

    A :py:class:`~etspy.base.TomoTilts` signal containing the
    tilt value (in degrees) for each projection in the stack.
    Should have the same navigation dimension as the stack.
    """
    return self._tilts

@tilts.setter
def tilts(self, new_tilts: Optional[Union[TomoTilts, np.ndarray]]):
    # validate/convert through the shared helper before storing
    validated = self.shift_and_tilt_setter("tilts", new_tilts)
    self._tilts = cast("TomoTilts", validated)

@tilts.deleter
def tilts(self):
    # "deleting" resets the tilts to zeros with the current shape and dtype
    zeroed = np.zeros_like(self.tilts.data)
    validated = self.shift_and_tilt_setter("tilts", zeroed)
    self._tilts = cast("TomoTilts", validated)
[docs]
def deepcopy(self):
    """
    Return a "deep copy" of this Stack.

    The copy is made with the standard library's :func:`~copy.deepcopy`, so the
    underlying data is duplicated in memory. This overrides
    :py:meth:`~hyperspy.api.signals.BaseSignal.deepcopy` so that the ``tilts``
    and ``shifts`` signals are deep-copied onto the new stack as well.

    See Also
    --------
    :py:meth:`~etspy.base.TomoStack.copy`
    """
    duplicate = copy.deepcopy(self)
    # assign through the property setters so the copies are validated
    tilts_copy = copy.deepcopy(self.tilts)
    duplicate.tilts = tilts_copy
    shifts_copy = copy.deepcopy(self.shifts)
    duplicate.shifts = shifts_copy
    return duplicate
[docs]
def copy(self):
    """
    Return a "shallow copy" of this Stack.

    The copy is made with the standard library's :func:`~copy.copy`. The new
    Stack does not duplicate the underlying data in memory; both Stacks
    reference the same data. This overrides
    :py:meth:`~hyperspy.api.signals.BaseSignal.copy` so that the ``tilts`` and
    ``shifts`` signals are shallow-copied onto the new stack as well.

    See Also
    --------
    :py:meth:`~etspy.base.TomoStack.deepcopy`
    """
    duplicate = copy.copy(self)
    # assign through the property setters so the copies are validated
    tilts_copy = copy.copy(self.tilts)
    duplicate.tilts = tilts_copy
    shifts_copy = copy.copy(self.shifts)
    duplicate.shifts = shifts_copy
    return duplicate
[docs]
def plot_sinos(self, *args, **kwargs):
    """
    Plot the TomoStack in sinogram orientation.

    The navigator defaults to ``"slider"``. It can be overridden with a
    ``navigator`` keyword argument, or by passing it as the first positional
    argument.

    Parameters
    ----------
    args
        Additional non-keyword arguments passed to
        :py:meth:`~hyperspy.api.signals.Signal2D.plot`
    kwargs
        Additional keyword arguments passed to
        :py:meth:`~hyperspy.api.signals.Signal2D.plot`
    """
    # Only supply the default navigator when it cannot collide with a
    # positional navigator in ``args`` or an explicit keyword; passing both
    # would raise a TypeError.
    if not args:
        kwargs.setdefault("navigator", "slider")
    sinograms = Signal2D(
        data=self.data,
        axes=list(self.axes_manager.as_dictionary().values()),
    )
    sinograms.swap_axes(1, 0).swap_axes(1, 2).plot(*args, **kwargs)
[docs]
def remove_projections(self, projections: Optional[list] = None) -> "TomoStack":
    """
    Return a copy of the TomoStack with certain projections removed from the series.

    This method is primarily provided as a helper/alternative way to
    modify the list of projections. It is recommended to use the
    ``.inav`` slicer of the :py:class:`~etspy.base.TomoStack` class instead.

    Parameters
    ----------
    projections
        List (or array) of projection indices in integers to remove. Duplicate
        indices, or indices referring to the same projection (e.g. ``0`` and
        ``-n``), are removed only once.

    Returns
    -------
    s_new : TomoStack
        Copy of self with indicated projections removed

    Raises
    ------
    ValueError
        If no projections are provided
    """
    if projections is None:
        msg = "No projections provided"
        raise ValueError(msg)

    keep = np.ones(self.data.shape[0], dtype=bool)
    keep[projections] = False
    # count what was actually removed (not len(projections)), so that the
    # axis sizes agree with the data when indices repeat or alias each other
    n_removed = int(keep.size - np.count_nonzero(keep))

    s_new = self.deepcopy()
    s_ax = cast("Uda", s_new.axes_manager[0])
    s_tilt_ax = cast("Uda", s_new.tilts.axes_manager[0])
    s_shift_ax = cast("Uda", s_new.shifts.axes_manager[0])
    for ax in (s_ax, s_tilt_ax, s_shift_ax):
        ax.size -= n_removed

    s_new.data = self.data[keep]
    s_new.tilts.data = self.tilts.data[keep]
    s_new.shifts.data = self.shifts.data[keep]
    return s_new
[docs]
def save(
    self,
    filename=None,
    overwrite=None,
    extension=None,
    file_format=None,
    **kwargs,
):
    """
    Save the signal in the specified format.

    Overloads the HyperSpy :py:meth:`~hyperspy.api.signals.BaseSignal.save`
    method so that the current tilts and shifts are stored in the metadata
    (``Tomography.tilts`` / ``Tomography.shifts``) before saving. Note that this
    updates ``self.metadata`` as a side effect. All arguments are the same as
    :py:meth:`~hyperspy.api.signals.BaseSignal.save`, so please consult that
    method's documentation for details.
    """
    tomo_items = (
        ("Tomography.tilts", self.tilts),
        ("Tomography.shifts", self.shifts),
    )
    for item_path, item_value in tomo_items:
        self.metadata.set_item(item_path, item_value)
    super().save(
        filename=filename,
        overwrite=overwrite,
        extension=extension,
        file_format=file_format,
        **kwargs,
    )
[docs]
def test_correlation(
    self,
    images: Optional[Union[list[int], tuple[int, int]]] = None,
) -> Figure:
    """
    Test output of cross-correlation prior to alignment.

    The two selected projections are shown next to the real part of their
    FFT-based cross-correlation (zero-shift at the image center).

    Parameters
    ----------
    images
        List of two numbers indicating which projections to cross-correlate.
        If ``None`` (or empty), the first two images will be used.

    Returns
    -------
    fig : ~matplotlib.figure.Figure
        Figure showing the results
    """
    if not images:
        images = [0, 1]
    reference = self.data[images[0], :, :]
    offset = self.data[images[1], :, :]

    spectrum_product = np.fft.fft2(reference) * np.fft.fft2(offset).conj()
    xcorr = np.fft.fftshift(np.fft.ifft2(spectrum_product))

    fig = plt.figure(figsize=(8, 3))
    ax_ref = fig.add_subplot(1, 3, 1)
    ax_off = fig.add_subplot(1, 3, 2, sharex=ax_ref, sharey=ax_ref)
    ax_cc = fig.add_subplot(1, 3, 3)

    panels = (
        (ax_ref, reference, "gray", "Reference image"),
        (ax_off, offset, "gray", "Offset image"),
        (ax_cc, xcorr.real, "inferno", "Cross-correlation"),
    )
    for ax, image, cmap, title in panels:
        ax.imshow(image, cmap=cmap)
        ax.set_axis_off()
        ax.set_title(title)
    return fig
# TODO: allow a list of signals for 'other'
[docs]
def align_other(
    self,
    other: "TomoStack",
    shift_type: Literal["interp", "fourier"] = "fourier",
    cuda: bool = False,
) -> "TomoStack":
    """
    Apply the alignment calculated for one dataset to another.

    This will include the spatial registration, tilt axis, and tilt axis
    shift if they have been previously calculated.

    Parameters
    ----------
    other
        The tilt series which is to be aligned using the previously
        calculated parameters. The data array in the TomoStack must be of
        the same size as that in ``self.data``
    shift_type
        Image shifts can be applied using either interpolation via
        scipy.ndimage.shift or via Fourier shift as implemented in
        scipy.ndimage.fourier_shift. Must be either 'interp' or 'fourier'.
    cuda
        Whether or not to use CUDA-accelerated reconstruction algorithms.

    Returns
    -------
    out : TomoStack
        The result of applying the alignment to other

    Raises
    ------
    ValueError
        If this stack has no per-projection shifts and its ``Tomography``
        metadata records no x-shift, no y-shift and no tilt-axis rotation,
        so there is nothing to transfer.
    """

    def _unset(value) -> bool:
        # A missing (None) or zero entry means "this transformation was not applied".
        return value is None or value == 0.0

    tomo_meta = cast("Dtb", self.metadata.Tomography)
    no_shifts = np.all(self.shifts.data == 0)
    # Missing metadata entries are treated the same as unset ones.
    no_xshift = _unset(getattr(tomo_meta, "xshift", None))
    # Previously this re-checked ``xshift``, so a stack whose only correction
    # was a y-shift was wrongly reported as untransformed.
    no_yshift = _unset(getattr(tomo_meta, "yshift", None))
    no_rotation = _unset(getattr(tomo_meta, "tiltaxis", None))
    if all([no_shifts, no_xshift, no_yshift, no_rotation]):
        msg = "No transformations have been applied to this stack"
        raise ValueError(msg)
    return align.align_to_other(self, other, shift_type, cuda)
[docs]
def filter(
    self,
    method: Literal["median", "bpf", "both", "sobel"] = "median",
    size: int = 5,
    taper: float = 0.1,
) -> "TomoStack":
    """
    Apply one of several image filters to an entire TomoStack.

    Parameters
    ----------
    method
        Type of filter to apply. Must be ``'median'``, ``'bpf'``, ``'both'``, or
        ``'sobel'``. ``'median'`` applies a per-projection median filter;
        ``'sobel'`` replaces each projection by its Sobel gradient magnitude;
        ``'both'`` applies the median filter followed by the Sobel filter;
        ``'bpf'`` applies a Fourier band-pass filter to each projection and
        then multiplies by a 2-D Hamming window.
    size
        Size of filtering neighborhood (used by ``'median'`` and ``'both'``).
    taper
        Fraction of the image size (per spatial axis) by which each edge of
        the filtered projections is zero-padded. ``0`` disables padding.

    Returns
    -------
    filtered : TomoStack
        Filtered copy of the input stack

    Examples
    --------
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data()
    >>> filtered = stack.filter(method='median')
    """

    def _sobel_magnitude(stack):
        # Replace each projection in-place with its Sobel gradient magnitude.
        # Done per projection because an N-D Sobel would also smooth across tilts.
        for i in range(stack.shape[0]):
            dx = ndimage.sobel(stack[i, :, :], 0)
            dy = ndimage.sobel(stack[i, :, :], 1)
            stack[i, :, :] = np.hypot(dx, dy)

    filtered = self.deepcopy()
    if method == "median":
        filtered.data = ndimage.median_filter(filtered.data, size=(1, size, size))
    elif method == "sobel":
        _sobel_magnitude(filtered.data)
    elif method == "both":
        filtered.data = ndimage.median_filter(filtered.data, size=(1, size, size))
        _sobel_magnitude(filtered.data)
    elif method == "bpf":
        lp_freq = 0.1
        hp_freq = 0.05
        lp_sigma = 1.5
        hp_sigma = 1.5
        _, rows, cols = self.data.shape
        # Only the two spatial axes are centred; the projection axis is left alone.
        fft = np.fft.fftshift(np.fft.fft2(self.data), axes=(-2, -1))
        x = (np.arange(0, cols) - np.fix(cols / 2)) / cols
        y = (np.arange(0, rows) - np.fix(rows / 2)) / rows
        xx, yy = np.meshgrid(x, y)
        r = np.sqrt(xx**2 + yy**2)
        lpf = 1 / (1.0 + (r / lp_freq) ** (2 * lp_sigma))
        hpf = 1 - (1 / (1.0 + (r / hp_freq) ** (2 * hp_sigma)))
        bpf = lpf * hpf
        fft_filtered = fft * bpf
        filtered.data = np.fft.ifft2(
            np.fft.ifftshift(fft_filtered, axes=(-2, -1)),
        ).real
        # Window must be (rows, cols); the previous (rows, rows) window failed
        # to broadcast for non-square projections.
        ham2d = np.sqrt(np.outer(np.hamming(rows), np.hamming(cols)))
        filtered.data = filtered.data * ham2d
    else:
        msg = (
            f'Invalid filter method "{method}". Must be one of '
            f"{_fmt(_get_lit(self.filter, 'method'))}."
        )
        raise ValueError(msg)
    if taper:
        taper_size = np.array(np.array(taper) * self.data.shape[1:], dtype=np.int32)
        filtered.data = np.pad(
            filtered.data,
            [
                (0, 0),
                (taper_size[0], taper_size[0]),
                (taper_size[1], taper_size[1]),
            ],
            mode="constant",
        )
    return filtered
[docs]
def stack_register(  # noqa: PLR0913
    self,
    method: AlignmentMethodType = AlignmentMethod.PC,
    start: Optional[int] = None,
    show_progressbar: bool = False,
    crop: bool = False,
    xrange: Optional[tuple[int, int]] = None,
    p: int = 20,
    nslices: int = 20,
    com_ref_index: Optional[int] = None,
    cl_ref_index: Optional[int] = None,
    cl_resolution: float = 0.05,
    cl_div_factor: int = 8,
    cuda: bool = False,
    shift_type: Literal["interp", "fourier"] = "fourier",
) -> "TomoStack":
    """
    Spatially register the projections of the stack.

    Supported approaches are phase correlation (PC) maximization, StackReg,
    center of mass ('COM'), and combined center of mass plus common line
    ('COM-CL'). The heavy lifting is done by
    :py:func:`etspy.align.align_stack`; consult it for algorithm details.

    Parameters
    ----------
    method
        Registration algorithm. Must be one of the values of the
        :py:class:`etspy.AlignmentMethod` enum.
    start
        Index in the tilt series used as the starting point of the
        alignment. If ``None``, the central projection is used.
    show_progressbar
        Enable/disable progress bar
    crop
        If True, crop the aligned stack to remove border pixels introduced by
        the shifts. Default is False.
    xrange
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM`)
        Range over which to perform the alignment. See
        :py:func:`~etspy.align.calculate_shifts_com` for more details.
    p
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM`)
        Padding element. See :py:func:`~etspy.align.calculate_shifts_com` for more
        details.
    nslices
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM`)
        Number of slices to return. See
        :py:func:`~etspy.align.calculate_shifts_com` for more details.
    com_ref_index
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM_CL`)
        Reference slice for center of mass alignment; all other slices are
        aligned to it. Defaults to the midpoint of the stack. See
        :py:func:`~etspy.align.calc_shifts_com_cl` for more details.
    cl_ref_index
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM_CL`)
        Reference slice for common line alignment; all other slices are
        aligned to it. Defaults to the midpoint of the stack.
    cl_resolution
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM_CL`)
        Resolution for subpixel common line alignment. Default is 0.05 and it
        should be less than 0.5. See
        :py:func:`~etspy.align.calc_shifts_com_cl` for more details.
    cl_div_factor
        (Only used when ``method ==``:py:attr:`~etspy.AlignmentMethod.COM_CL`)
        Factor controlling how many common line iterations are performed.
        Default is 8. See :py:func:`~etspy.align.calc_shifts_com_cl` for more
        details.
    cuda
        Whether or not to use CUDA-accelerated reconstruction algorithms.
    shift_type
        Calculated image shifts can be applied using either interpolation via
        scipy.ndimage.shift or via Fourier shift as implemented in
        scipy.ndimage.fourier_shift. Must be either 'interp' or 'fourier'.

    Returns
    -------
    out : TomoStack
        Spatially registered copy of the input stack

    Raises
    ------
    TypeError
        If ``method`` is not a valid :py:class:`etspy.AlignmentMethod` value.

    Examples
    --------
    Registration with phase correlation algorithm (PC)
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data()
    >>> regPC = stack.stack_register('PC')
    Registration with center of mass tracking (COM)
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data()
    >>> regCOM = stack.stack_register('COM')
    Registration with StackReg
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data()
    >>> regSR = stack.stack_register('StackReg')
    Registration with center of mass and common line (COM-CL)
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data()
    >>> regCOMCL = stack.stack_register('COM-CL')
    """
    # Reject unknown methods up front.
    if not AlignmentMethod.is_valid_value(method):
        msg = (
            f'Invalid registration method "{method}". '
            f"Must be one of {_fmt(AlignmentMethod.values())}."
        )
        raise TypeError(msg)

    registered = align.align_stack(
        self,
        method,
        start,
        show_progressbar,
        xrange=xrange,
        p=p,
        nslices=nslices,
        com_ref_index=com_ref_index,
        cl_ref_index=cl_ref_index,
        cl_resolution=cl_resolution,
        cl_div_factor=cl_div_factor,
        cuda=cuda,
        shift_type=shift_type,
    )
    if crop:
        registered = align.shift_crop(registered)
    return registered
[docs]
def tilt_align(
    self,
    method: Literal["CoM", "MaxImage"],
    slices: Optional[np.ndarray] = None,
    nslices: Optional[int] = None,
    limit: float = 10,
    delta: float = 0.1,
    plot_results: bool = False,
    also_shift: bool = False,
    shift_limit: int = 20,
):
    """
    Align the tilt axis of a TomoStack.

    Two strategies are available, ``'CoM'`` and ``'MaxImage'``:

    **CoM:**
    Track the center of mass (CoM) of the projections at several
    locations. Fit the motion of the CoM as a function of tilt to that
    expected for an ideal cylinder to calculate an X-shift at each
    location. Perform a linear fit of the X-shifts to calculate an
    ideal rotation.

    **MaxImage:**
    Determine the tilt axis automatically from features in the projected
    maximum image. Edge detection combined with Hough transform analysis
    gives the global rotation of the stack. Optionally, the global shift of
    the tilt axis can also be found by minimizing the sum of the
    reconstruction.

    Parameters
    ----------
    method
        Algorithm to use for registration alignment. Must be either ``'CoM'`` or
        ``'MaxImage'``.
    slices
        (Only used when ``method == "CoM"``)
        Locations at which to perform the Center of Mass analysis. If not
        provided, an appropriate list of slices will be automatically determined.
    nslices
        (Only used when ``method == "CoM"``)
        Number of slices to use for the center of mass analysis (only used if the
        ``slices`` parameter is not specified). If ``None``, a value of 10% of the
        x-axis size will be used, clamped to the range [3, 50], as calculated in
        the :py:func:`~etspy.align.tilt_com` function.
    limit
        (Only used when ``method == "MaxImage"``)
        Maximum rotation angle for MaxImage calculation
    delta
        (Only used when ``method == "MaxImage"``)
        Angular increment in degrees for MaxImage calculation
    plot_results
        (Only used when ``method == "MaxImage"``)
        If ``True``, plot the maximum image along with the lines determined
        by Hough analysis
    also_shift
        (Only used when ``method == "MaxImage"``)
        If ``True``, also calculate and apply the global shift perpendicular to the
        tilt by minimizing the sum of the reconstruction
    shift_limit
        (Only used when ``method == "MaxImage"``)
        The limit of shifts applied if ``also_shift`` is set to ``True``

    Returns
    -------
    out : TomoStack
        Copy of the input stack rotated by calculated angle

    Raises
    ------
    ValueError
        If ``method`` is neither ``'CoM'`` nor ``'MaxImage'``.

    Examples
    --------
    Align tilt axis using the center of mass (CoM) method:
    >>> import etspy.datasets as ds
    >>> import numpy as np
    >>> stack = ds.get_needle_data()
    >>> reg = stack.stack_register('PC', show_progressbar=False)
    >>> method = 'CoM'
    >>> ali = reg.tilt_align(method, slices=np.array([50,100,160]))
    Align tilt axis using the maximum image method:
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data()
    >>> reg = stack.stack_register('PC', show_progressbar=False)
    >>> method = 'MaxImage'
    >>> ali = reg.tilt_align(method)
    """
    if method not in ("CoM", "MaxImage"):
        msg = (
            f'Invalid alignment method "{method}". Must be one of '
            f"{_fmt(_get_lit(self.tilt_align, 'method'))}."
        )
        raise ValueError(msg)
    if method == "CoM":
        return align.tilt_com(self, slices, nslices)
    # Only "MaxImage" remains at this point.
    return align.tilt_maximage(
        self,
        limit,
        delta,
        plot_results,
        also_shift,
        shift_limit,
    )
[docs]
def reconstruct(  # noqa: PLR0913
    self,
    method: ReconMethodType = "FBP",
    iterations: int = 5,
    constrain: bool = False,
    thresh: float = 0,
    cuda: Optional[bool] = None,
    thickness: Optional[int] = None,
    show_progressbar: bool = True,
    p: float = 0.99,
    ncores: Optional[int] = None,
    sino_filter: FbpMethodType = "shepp-logan",
    dart_iterations: Optional[int] = 5,
    gray_levels: Optional[Union[list, np.ndarray]] = None,
) -> "RecStack":
    """
    Reconstruct a TomoStack series using one of the available methods.

    Parameters
    ----------
    method
        Reconstruction algorithm to use (case-insensitive). Must be one of
        ``"FBP"`` (default), ``"SIRT"``, ``"SART"``, or ``"DART"``
    iterations
        Number of iterations for the SIRT reconstruction (used with ``SIRT``,
        ``SART``, and ``DART`` methods) (default: 5)
    constrain
        If ``True``, output reconstruction is constrained above value given
        by ``thresh``
    thresh
        Value above which to constrain the reconstructed data
    cuda
        Whether or not to use CUDA-accelerated reconstruction algorithms. If
        ``None`` (the default), the decision to use CUDA will be left to
        :py:func:`astra.astra.use_cuda`.
    thickness
        Size of the output volume (in pixels) in the projection direction. If
        ``None``, the y-size of the stack is used.
    show_progressbar
        If ``True``, show a progress bar for the reconstruction. Default is
        ``True``.
    p
        Probability for setting free pixels in DART reconstruction (only used
        if the reconstruction method is DART, default: 0.99)
    ncores
        Number of cores to use for multithreaded reconstructions.
    sino_filter
        Filter for filtered backprojection. Default is ``"shepp-logan"``.
        Available options are detailed in the Astra Toolbox documentation
        under the ``cfg.FilterType`` option of
        :external+astra:doc:`docs/algs/FBP_CUDA`.
    dart_iterations
        Number of iterations to employ for DART reconstruction (ignored for
        the other methods; ``None`` falls back to 5 for DART)
    gray_levels
        List of gray levels to use for DART reconstruction (required for
        DART, ignored otherwise)

    Returns
    -------
    rec : RecStack
        RecStack containing the reconstructed volume

    Raises
    ------
    ValueError
        If ``method`` is unknown, if all tilt angles are zero, or if DART is
        requested without a list/array of ``gray_levels``.

    Examples
    --------
    Filtered backprojection (FBP) reconstruction:
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data(aligned=True)
    >>> slices = stack.isig[:, 120:121].deepcopy()
    >>> rec = slices.reconstruct('FBP', cuda=False, show_progressbar=False)
    Simultaneous iterative reconstruction technique (SIRT) reconstruction:
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data(aligned=True)
    >>> slices = stack.isig[:, 120:121].deepcopy()
    >>> rec = slices.reconstruct('SIRT',iterations=5,
    ...                          cuda=False, show_progressbar=False)
    SIRT reconstruction with positivity constraint:
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data(aligned=True)
    >>> slices = stack.isig[:, 120:121].deepcopy()
    >>> iterations = 5
    >>> constrain = True
    >>> thresh = 0
    >>> rec = slices.reconstruct('SIRT', iterations, constrain, thresh,
    ...                          cuda=False, show_progressbar=False)
    Discrete algebraic reconstruction technique (DART) reconstruction:
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data(aligned=True)
    >>> slices = stack.isig[:, 120:121].deepcopy()
    >>> gray_levels = [0., slices.data.max()/2, slices.data.max()]
    >>> rec = slices.reconstruct('DART', iterations=5, cuda=False,
    ...                          gray_levels=gray_levels, p=0.99,
    ...                          dart_iterations=5, show_progressbar=False)
    """
    algorithm = method.lower()
    if algorithm not in ("fbp", "sirt", "sart", "dart"):
        msg = (
            f'Invalid reconstruction algorithm "{method}". Must be one of '
            f"{_fmt(_get_lit(self.reconstruct, 'method'))}."
        )
        raise ValueError(msg)
    if np.all(self.tilts.data == 0):
        msg = (
            "Tilts are not defined in stack.tilts (values were all zeros). "
            "Please set tilt values before alignment."
        )
        raise ValueError(msg)

    if cuda is None:
        cuda = bool(astra.use_cuda())
        if cuda:  # coverage: nocuda
            logger.info("CUDA detected with Astra")
        else:
            logger.info("CUDA not detected with Astra")

    if algorithm != "dart":
        # DART-only options are discarded for every other algorithm.
        dart_iterations = None
        gray_levels = None
    else:
        if gray_levels is None:
            msg = "gray_levels must be provided for DART"
            raise ValueError(msg)
        if not isinstance(gray_levels, (np.ndarray, list)):
            msg = f"Unknown type ({type(gray_levels)}) for gray_levels"
            raise ValueError(msg)
        if dart_iterations is None:
            logger.info("Using default number of DART iterations (5)")
            dart_iterations = 5

    volume = recon.run(
        stack=self.data,
        tilts=self.tilts.data,
        method=method,
        niterations=iterations,
        constrain=constrain,
        thresh=thresh,
        cuda=cuda,
        thickness=thickness,
        ncores=ncores,
        bp_filter=sino_filter,
        gray_levels=gray_levels,
        dart_iterations=cast("int", dart_iterations),
        p=p,
        show_progressbar=show_progressbar,
    )

    # Output axes: [axis-2 of the stack, a new "z" axis sized to the
    # reconstruction, axis-1 of the stack].
    axes_dict = self.axes_manager.as_dictionary()
    z_axis = dict(axes_dict["axis-1"])
    z_axis["name"] = "z"
    z_axis["size"] = volume.shape[1]
    return RecStack(volume, axes=[axes_dict["axis-2"], z_axis, axes_dict["axis-1"]])
[docs]
def test_align(  # noqa: PLR0913
    self,
    tilt_shift: float = 0.0,
    tilt_rotation: float = 0.0,
    slices: Optional[np.ndarray] = None,
    thickness: Optional[int] = None,
    method: Literal["FBP", "SIRT", "SART"] = "FBP",
    iterations: int = 50,
    constrain: bool = True,
    cuda: Optional[bool] = None,
    thresh: float = 0,
    vmin_std: float = 0.1,
    vmax_std: float = 10,
) -> Figure:
    """
    Perform a reconstruction with limited slices for visual inspection.

    This method is useful to quickly test the alignment of the stack prior
    to a full reconstruction attempt.

    Parameters
    ----------
    tilt_shift
        Number of pixels by which to shift the stack prior to reconstruction
    tilt_rotation
        Angle by which to rotate stack prior to reconstruction
    slices
        Position(s) of the slices to use for the reconstruction; a scalar or
        any sequence of indices is accepted and one panel is plotted per
        slice. If ``None`` (default), three positions at 1/4, 1/2, and 3/4 of
        the full size of the stack are automatically chosen.
    thickness
        Size of the output volume (in pixels) in the projection direction.
        If ``None`` (default), the y-size of the stack is used.
    method
        Reconstruction algorithm to use. Must be one of ``"FBP"`` (default),
        ``"SIRT"``, or ``"SART"`` (inapplicable for the ``"DART"`` algorithm)
    iterations
        Number of iterations for the SIRT reconstruction (used with ``SIRT``,
        and ``SART`` methods) (default: 50)
    constrain
        If ``True`` (default), the reconstruction is constrained above the
        value given by ``thresh``
    cuda
        Whether or not to use CUDA-accelerated reconstruction algorithms. If
        ``None`` (the default), the decision to use CUDA will be left to
        :py:func:`astra.astra.use_cuda`.
    thresh
        Value above which to constrain the reconstructed data
    vmin_std
        Number of standard deviations from mean (lower bound) to use for scaling the
        displayed slices
    vmax_std
        Number of standard deviations from mean (upper bound) to use for scaling the
        displayed slices

    Returns
    -------
    fig : :py:class:`~matplotlib.figure.Figure`
        One panel per reconstructed slice

    Raises
    ------
    ValueError
        If ``slices`` is empty.
    """
    if slices is None:
        mid = np.array(self.data.shape[2] / 2, dtype=np.int32)
        slices = np.array([mid / 2, mid, mid + mid / 2], dtype=np.int32)
    # Accept a scalar or any number of indices. The plotting code previously
    # assumed exactly three slices (IndexError for fewer, extras were dropped).
    slices = np.atleast_1d(slices)
    nslices = len(slices)
    if nslices == 0:
        msg = "At least one slice index is required"
        raise ValueError(msg)

    if (tilt_shift != 0.0) or (tilt_rotation != 0.0):
        shifted = self.trans_stack(xshift=0, yshift=tilt_shift, angle=tilt_rotation)
    else:
        shifted = self.deepcopy()
    shifted = cast("TomoStack", shifted)
    shifted.data = shifted.data[:, :, slices]
    cast("Uda", shifted.axes_manager[0]).axis = cast(
        "Uda",
        self.axes_manager[0],
    ).axis
    if cuda is None:
        if astra.use_cuda():  # coverage: nocuda
            logger.info("CUDA detected with Astra")
            cuda = True
        else:
            cuda = False
            logger.info("CUDA not detected with Astra")
    rec = shifted.reconstruct(
        method=method,
        iterations=iterations,
        constrain=constrain,
        thickness=thickness,
        cuda=cuda,
        thresh=thresh,
        show_progressbar=False,
    )

    # Figure sizes are tuned for three panels; scale the width for other counts.
    backend = mpl.get_backend().lower()
    if "ipympl" in backend:
        width, height = 7, 3
    elif "nbagg" in backend:
        width, height = 8, 4
    else:
        width, height = 12, 4
    fig, axes = plt.subplots(
        1,
        nslices,
        figsize=(width * nslices / 3, height),
        squeeze=False,
    )

    minvals = rec.data.mean((1, 2)) - vmin_std * rec.data.std((1, 2))
    minvals[minvals < 0] = 0
    maxvals = rec.data.mean((1, 2)) + vmax_std * rec.data.std((1, 2))
    for ax, index, image, lo, hi in zip(axes[0], slices, rec.data, minvals, maxvals):
        ax.imshow(image, cmap="afmhot", vmin=lo, vmax=hi)
        ax.set_title(f"Slice {index}")
        ax.set_axis_off()
    fig.tight_layout()
    return fig
[docs]
def set_tilts(self, start: float, increment: float):
    """
    Calibrate the tilt axis of the image stack.

    The navigation axis is set to ``start + increment * index`` (in
    degrees). One tilt angle per projection is stored in ``self.tilts``, and
    the ``Tomography`` metadata entries ``tiltaxis``, ``xshift``, ``yshift``
    and ``cropped`` are reset.

    Parameters
    ----------
    start
        Tilt angle of first image in stack
    increment
        Tilt increment between images
    """
    nimages = self.data.shape[0]
    ax = cast("Uda", self.axes_manager[0])
    ax.name = "Projections"
    ax.units = "degrees"
    ax.scale = increment
    ax.offset = start
    # Build angles from an integer index: a float-stepped np.arange with a
    # computed stop can yield one element too many (e.g. start=0,
    # increment=0.1, 3 images -> 4 values), mismatching the stack.
    tilts = start + increment * np.arange(nimages)
    if not self.metadata.has_item("Tomography"):
        self.metadata.add_node("Tomography")
    tomo_meta = cast("Dtb", self.metadata.Tomography)
    tomo_meta.set_item("tiltaxis", 0)
    tomo_meta.set_item("xshift", 0)
    tomo_meta.set_item("yshift", 0)
    tomo_meta.set_item("cropped", value=False)
    self.tilts = tilts
[docs]
def manual_align(  # noqa: PLR0915
    self,
    nslice: int,
    xshift: int = 0,
    yshift: int = 0,
    display: bool = False,
) -> "TomoStack":
    """
    Manually shift part of a stack with respect to another and return it as a copy.

    Projections ``[:nslice]`` and ``[nslice:]`` are cropped differently so
    that the second portion is offset by (``xshift``, ``yshift``) pixels
    relative to the first. The output is smaller than the input by
    ``abs(yshift)`` rows and ``abs(xshift)`` columns.

    Parameters
    ----------
    nslice
        Slice position at which to implement shift
    xshift
        Number of pixels with which to shift the second portion of the
        stack relative to the first in the X dimension.
    yshift
        Number of pixels with which to shift the second portion of the
        stack relative to the first in the Y dimension.
    display
        If True, display the images on either side of ``nslice`` and their
        difference, before and after the shift.

    Returns
    -------
    output : TomoStack
        Shifted (and cropped) copy of the input stack
    """

    def _crop_slices(shift):
        # (first-portion slice, second-portion slice) for one spatial axis:
        # trim ``abs(shift)`` pixels from opposite ends of the two portions.
        if shift > 0:
            return slice(shift, None), slice(None, -shift)
        if shift < 0:
            return slice(None, shift), slice(-shift, None)
        return slice(None), slice(None)

    output = self.deepcopy()
    if xshift != 0 or yshift != 0:
        y_first, y_second = _crop_slices(yshift)
        x_first, x_second = _crop_slices(xshift)
        output.data = np.concatenate(
            [
                self.data[:nslice, y_first, x_first],
                self.data[nslice:, y_second, x_second],
            ],
            axis=0,
        )
    if display:
        old_im1 = self.data[nslice - 1, :, :]
        old_im2 = self.data[nslice, :, :]
        new_im1 = output.data[nslice - 1, :, :]
        new_im2 = output.data[nslice, :, :]
        # Normalise each image to [0, 1] so the difference maps are comparable.
        old_im1 = old_im1 - old_im1.min()
        old_im1 = old_im1 / old_im1.max()
        old_im2 = old_im2 - old_im2.min()
        old_im2 = old_im2 / old_im2.max()
        new_im1 = new_im1 - new_im1.min()
        new_im1 = new_im1 / new_im1.max()
        new_im2 = new_im2 - new_im2.min()
        new_im2 = new_im2 / new_im2.max()
        fig, ax = plt.subplots(2, 3)
        ax[0, 0].imshow(old_im1)
        ax[0, 1].imshow(old_im2)
        ax[0, 2].imshow(old_im1 - old_im2, clim=[-0.5, 0.5])
        ax[1, 0].imshow(new_im1)
        ax[1, 1].imshow(new_im2)
        ax[1, 2].imshow(new_im1 - new_im2, clim=[-0.5, 0.5])
    return output
[docs]
def recon_error(
    self,
    nslice: Optional[int] = None,
    algorithm: Literal["SIRT", "SART"] = "SIRT",
    iterations: int = 50,
    constrain: bool = True,
    cuda: Optional[bool] = None,
    thresh: float = 0,
) -> tuple[Signal2D, Signal1D]:
    """
    Determine the optimum number of iterations for reconstruction.

    Evaluates the difference between reconstruction and input data
    at each iteration and terminates when the change between iterations is
    below tolerance.

    Parameters
    ----------
    nslice
        Slice location (index along the last data axis) at which to perform
        the evaluation. If ``None``, the middle slice is used; ``0`` is a
        valid index.
    algorithm
        Reconstruction algorithm use. Must be either ``'SIRT'`` (default)
        or ``'SART'`` (this method is inapplicable for ``'FBP'`` and ``'DART'``).
    iterations
        Maximum number of reconstruction iterations to evaluate (default: 50)
    constrain
        If True, perform the reconstruction with a non-negativity
        constraint. Default is ``True``
    cuda
        Whether or not to use CUDA-accelerated reconstruction algorithms. If
        ``None`` (the default), the decision to use CUDA will be left to
        :py:func:`astra.astra.use_cuda`.
    thresh
        Value above which to constrain the reconstructed data

    Returns
    -------
    rec_stack : :py:class:`~hyperspy.api.signals.Signal2D`
        Signal containing the reconstruction (SIRT or SART) at each
        iteration for visual inspection.
    error : :py:class:`~hyperspy.api.signals.Signal1D`
        Sum of squared difference between the forward-projected
        reconstruction and the input sinogram at each iteration

    Raises
    ------
    ValueError
        If all tilt angles are zero.

    Examples
    --------
    >>> import etspy.datasets as ds
    >>> stack = ds.get_needle_data(aligned=True)
    >>> rec_stack, error = stack.recon_error(iterations=5)
    """
    if np.all(self.tilts.data == 0):
        msg = "Tilt angles not defined"
        raise ValueError(msg)
    # Only a missing value gets the default; ``if not nslice`` used to
    # silently replace the valid index 0 with the middle slice.
    if nslice is None:
        nslice = self.data.shape[2] // 2
    if cuda is None:
        if astra.use_cuda():  # coverage: nocuda
            logger.info("CUDA detected with Astra")
            cuda = True
        else:
            cuda = False
            logger.info("CUDA not detected with Astra")
    sinogram = self.isig[nslice : nslice + 1, :].data.squeeze()
    rec_stack, error = recon.astra_error(
        sinogram,
        angles=self.tilts.data,
        method=algorithm,
        iterations=iterations,
        constrain=constrain,
        thresh=thresh,
        cuda=cuda,
    )
    rec_stack = Signal2D(rec_stack)
    rec_ax0, rec_ax1, rec_ax2 = (
        cast("Uda", rec_stack.axes_manager[i]) for i in range(3)
    )
    self_ax2 = cast("Uda", self.axes_manager[2])
    rec_ax0.name = algorithm.upper() + " iteration"
    rec_ax0.scale = 1
    rec_ax1.name = self_ax2.name
    rec_ax1.scale = self_ax2.scale
    rec_ax1.units = self_ax2.units
    rec_ax2.name = "z"
    rec_ax2.scale = self_ax2.scale
    rec_ax2.units = self_ax2.units
    rec_stack.navigator = "slider"
    error = Signal1D(error)
    cast("Uda", error.axes_manager[0]).name = algorithm.upper() + " Iteration"
    cast("Dtb", error.metadata.Signal).quantity = "Sum of Squared Difference"
    return rec_stack, error
[docs]
class RecStack(CommonStack):
"""
Create a RecStack instance, used to hold the results of a reconstructed volume.
All arguments are passed to the :py:class:`~hyperspy.api.signals.Signal2D`
constructor and should be used as documented for that method.
Group
-----
signals
Order
-----
2
"""
_signal_type = "RecStack"
_signal_dimension = 2
[docs]
def __init__(self, *args, **kwargs):
    """
    Build a RecStack and check that it has at most one navigation axis.

    Parameters
    ----------
    args
        Positional arguments forwarded verbatim to
        :py:class:`~hyperspy.api.signals.Signal2D`
    kwargs
        Keyword arguments forwarded verbatim to
        :py:class:`~hyperspy.api.signals.Signal2D`

    Raises
    ------
    ValueError
        If the resulting signal has a navigation dimension other than 0 or 1.
    """
    super().__init__(*args, **kwargs)
    manager = self.axes_manager
    if manager.navigation_dimension in (0, 1):
        # Typed slicers so that ``inav``/``isig`` return RecStack objects.
        self.inav = _RecStackSlicer(self, isNavigation=True)
        self.isig = _RecStackSlicer(self, isNavigation=False)
        return
    msg = (
        "A RecStack must have a singular (or no) navigation axis. "
        f"Navigation shape was: {manager.navigation_shape}"
    )
    raise ValueError(msg)
[docs]
def forward_project(
    self,
    tilts: Optional[Union[TomoTilts, np.ndarray]],
    cuda: Optional[bool] = None,
) -> TomoStack:
    """
    Forward project the RecStack signal.

    A 2-D volume is projected as a single slice. A 3-D volume is iterated
    over its first axis, each ``(nz, ny)`` slice being projected
    independently.

    Parameters
    ----------
    tilts
        Tilt angles in degrees to use for the projection
        :py:class:`~numpy.ndarray`
    cuda
        Whether or not to use CUDA-accelerated reconstruction algorithms. If
        ``None`` (the default), the decision to use CUDA will be left to
        :py:func:`astra.astra.use_cuda`.

    Returns
    -------
    stack : TomoStack
        Projections with shape ``(ntilts, ny, nx)`` (``nx`` is 1 for a 2-D
        volume), built with the supplied ``tilts``
    """
    if cuda is None:
        if astra.use_cuda():  # coverage: nocuda
            logger.info("CUDA detected with Astra")
            cuda = True
        else:
            cuda = False
            logger.info("CUDA not detected with Astra")
    projector_type = "cuda" if cuda else "linear"
    try:
        if len(self.data.shape) == 2:  # noqa: PLR2004
            nz, ny = self.data.shape
            thetas = np.deg2rad(tilts)
            vol_geom = astra.create_vol_geom(nz, ny)
            proj_geom = astra.create_proj_geom("parallel", 1.0, ny, thetas)
            proj_id = astra.create_projector(projector_type, proj_geom, vol_geom)
            _, proj = astra.create_sino(self.data, proj_id)
            proj = proj[:, :, np.newaxis]
        else:
            nx, nz, ny = self.data.shape
            ntilts = len(tilts)
            proj = np.zeros([ntilts, ny, nx])
            thetas = np.deg2rad(tilts)
            # Each slice is (nz, ny); the geometry previously used (ny, ny),
            # which only worked when the reconstruction thickness equalled ny.
            vol_geom = astra.create_vol_geom(nz, ny)
            proj_geom = astra.create_proj_geom("parallel", 1.0, ny, thetas)
            proj_id = astra.create_projector(projector_type, proj_geom, vol_geom)
            for i in range(nx):
                sino_id, proj[:, :, i] = astra.create_sino(self.data[i, :, :], proj_id)
                # Free each sinogram right away instead of holding all nx of
                # them in Astra's memory until the end.
                astra.data2d.delete(sino_id)
    finally:
        # Release Astra objects even if a projection fails.
        astra.clear()
    return TomoStack(proj, tilts)
[docs]
def plot_slices(
    self,
    xslice: Optional[int] = None,
    yslice: Optional[int] = None,
    zslice: Optional[int] = None,
    vmin_std: float = 0.1,
    vmax_std: float = 5,
):
    """
    Plot slices along all three axes of a reconstruction stack.

    The data is indexed as ``[x, z, y]``: ``xslice`` selects along axis 0,
    ``zslice`` along axis 1 and ``yslice`` along axis 2.

    Parameters
    ----------
    xslice, yslice, zslice
        Indices of slices to plot. If ``None`` (default), the middle
        slice of the corresponding axis will be used.
    vmin_std, vmax_std
        Number of standard deviations from mean to use for
        scaling the displayed slices

    Returns
    -------
    fig : ~matplotlib.figure.Figure
        The figure containing a view of the three slices
    """
    # Each default is the middle of the axis it indexes. The y and z defaults
    # used to be computed from each other's axis, which could go out of range
    # when the thickness differs from the y size.
    if xslice is None:
        xslice = self.data.shape[0] // 2
    if yslice is None:
        yslice = self.data.shape[2] // 2
    if zslice is None:
        zslice = self.data.shape[1] // 2
    backend = mpl.get_backend().lower()
    if "ipympl" in backend:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(7, 3))
    elif "nbagg" in backend:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 4))
    else:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
    # Oriented as (rows, cols) = (first, second) letter of each panel title.
    slices = [
        self.data[xslice, :, :],  # rows Z, cols Y
        self.data[:, zslice, :].T,  # rows Y, cols X (was untransposed, mislabeled)
        self.data[:, :, yslice].T,  # rows Z, cols X
    ]
    minvals = [s.mean() - vmin_std * s.std() for s in slices]
    minvals = [x if x >= 0 else 0 for x in minvals]
    maxvals = [s.mean() + vmax_std * s.std() for s in slices]
    ax1.imshow(slices[0], cmap="afmhot", vmin=minvals[0], vmax=maxvals[0])
    ax1.set_title(f"Z-Y Slice {xslice}")
    ax1.set_ylabel("Z")
    ax1.set_xlabel("Y")
    ax2.imshow(slices[1], cmap="afmhot", vmin=minvals[1], vmax=maxvals[1])
    ax2.set_title(f"Y-X Slice {zslice}")
    ax2.set_ylabel("Y")
    ax2.set_xlabel("X")
    ax3.imshow(slices[2], cmap="afmhot", vmin=minvals[2], vmax=maxvals[2])
    ax3.set_title(f"Z-X Slice {yslice}")
    ax3.set_ylabel("Z")
    ax3.set_xlabel("X")
    fig.tight_layout()
    for ax in (ax1, ax2, ax3):
        ax.set_xticks([])
        ax.set_yticks([])
    return fig