"""Utility module for ETSpy package."""
import logging
from multiprocessing import Pool
from typing import Literal, Optional, cast
import numpy as np
import tqdm
from hyperspy._signals.signal2d import Signal2D
from hyperspy.axes import UniformDataAxis as Uda
from pystackreg import StackReg
from scipy import ndimage
from etspy import _format_choices as _fmt
from etspy import _get_literal_hint_values as _get_lit
from etspy.align import calculate_shifts_stackreg
from etspy.base import TomoStack
def multiaverage(stack: np.ndarray, nframes: int, ny: int, nx: int) -> np.ndarray:
    """
    Align and average the frames collected at a single tilt by SerialEM.

    Frames are registered to one another with a translation-only StackReg
    alignment, shifted into register, and then averaged.

    Parameters
    ----------
    stack
        Array of shape [nframes, ny, nx].
    nframes
        Number of frames per tilt.
    ny
        Pixels in y-dimension.
    nx
        Pixels in x-dimension.

    Returns
    -------
    average : :py:class:`~numpy.ndarray`
        Average of all frames at given tilt

    Group
    -----
    utilities
    """

    def _calc_sr_shifts(frames: np.ndarray) -> np.ndarray:
        # Register each frame to the previous one. The affine matrices
        # returned hold the (x, y) translation in column 2; it is reversed
        # to (y, x) order and negated to produce shifts suitable for
        # scipy.ndimage.shift.
        sr = StackReg(StackReg.TRANSLATION)
        transforms = sr.register_stack(frames, reference="previous")
        return -np.array([t[0:2, 2][::-1] for t in transforms])

    shifts = _calc_sr_shifts(stack)
    shifted = np.zeros([nframes, ny, nx])
    for k in range(nframes):
        shifted[k, :, :] = ndimage.shift(
            stack[k, :, :],
            shift=[shifts[k, 0], shifts[k, 1]],
        )
    return shifted.mean(0)
def register_serialem_stack(stack: Signal2D, ncpus: int = 1) -> TomoStack:
    """
    Register a multi-frame series collected by SerialEM.

    Parameters
    ----------
    stack
        Signal of shape [ntilts, nframes, ny, nx].
    ncpus
        Number of CPUs to use for registration. If 1 (default) or fewer,
        tilts are processed serially; otherwise a multiprocessing pool of
        ``ncpus`` workers aligns and averages the tilts in parallel.

    Returns
    -------
    reg : TomoStack
        Result of aligning and averaging frames at each tilt with shape [ntilts, ny, nx]

    Group
    -----
    utilities
    """
    # Temporarily silence the alignment logger so the per-tilt registration
    # calls do not flood the log; the previous level is restored on exit.
    align_logger = logging.getLogger("etspy.align")
    log_level = align_logger.getEffectiveLevel()
    align_logger.setLevel(logging.ERROR)

    ntilts, nframes, ny, nx = stack.data.shape

    if ncpus <= 1:
        # Serial path: align and average the frames of each tilt in turn
        reg = np.zeros([ntilts, ny, nx], stack.data.dtype)
        for i in tqdm.tqdm(range(ntilts)):
            shifted = np.zeros([nframes, ny, nx])
            shifts = calculate_shifts_stackreg(
                stack.inav[:, i],
                start=None,
                show_progressbar=False,
            )
            for k in range(nframes):
                shifted[k, :, :] = ndimage.shift(
                    stack.data[i, k, :, :],
                    shift=[shifts[k, 0], shifts[k, 1]],
                )
            reg[i, :, :] = shifted.mean(0)
    else:
        # Parallel path: each tilt is aligned/averaged in a worker process
        with Pool(ncpus) as pool:
            reg = pool.starmap(
                multiaverage,
                [(stack.inav[:, i].data, nframes, ny, nx) for i in range(ntilts)],
            )
        reg = np.array(reg)

    reg = TomoStack(reg)

    # Copy the axis calibration (scale/offset/units) from the input's
    # frame axes (indices 1-3) onto the registered stack's axes (0-2)
    for ax_idx in range(3):
        reg_ax = cast(Uda, reg.axes_manager[ax_idx])
        stack_ax = cast(Uda, stack.axes_manager[ax_idx + 1])
        reg_ax.scale = stack_ax.scale
        reg_ax.offset = stack_ax.offset
        reg_ax.units = stack_ax.units

    # Carry over relevant metadata when present on the input signal
    if stack.metadata.has_item("Acquisition_instrument"):
        reg.metadata.Acquisition_instrument = stack.metadata.Acquisition_instrument
    if stack.metadata.has_item("Tomography"):
        reg.metadata.Tomography = stack.metadata.Tomography

    align_logger.setLevel(log_level)
    return reg
def weight_stack(
    stack: TomoStack,
    accuracy: Literal["low", "medium", "high"] = "medium",
) -> TomoStack:
    """
    Apply a weighting window to a stack perpendicular to the tilt axis.

    This weighting is useful for reducing the effects of mass introduced at the
    edges of a stack when determining alignments based on the center of mass.

    As described in:

    T. Sanders. Physically motivated global alignment method for electron
    tomography, Advanced Structural and Chemical Imaging vol. 1 (2015) pp 1-11.
    https://doi.org/10.1186/s40679-015-0005-7

    Parameters
    ----------
    stack
        The stack to be weighted.
    accuracy
        A string (case-insensitive) indicating the accuracy level for
        weighting. Options are: 'low', 'medium' (default), or 'high'.

    Returns
    -------
    stackw : TomoStack
        The weighted version of the input stack.

    Raises
    ------
    ValueError
        If ``accuracy`` is not one of the recognized options.

    Group
    -----
    utilities
    """
    # Map the accuracy level to the number of coarse iterations and the
    # window-severity step size, with a default of "medium". Note: the
    # comparison is done on the lower-cased value so mixed-case input
    # (e.g. "High") is accepted.
    accuracy_normalized = accuracy.lower()
    niterations = 2000
    delta = 0.01
    if accuracy_normalized == "low":
        niterations = 800
        delta = 0.025
    elif accuracy_normalized == "medium":
        pass
    elif accuracy_normalized == "high":
        niterations = 20000
        delta = 0.001
    else:
        msg = (
            f'Invalid accuracy level "{accuracy}". Must be one of '
            f"{_fmt(_get_lit(weight_stack, 'accuracy'))}."
        )
        raise ValueError(msg)

    weighted_stack = stack.deepcopy()

    # Get stack dimensions
    ntilts, ny, nx = weighted_stack.data.shape

    # Total projected mass of each tilt image; the tilt with the minimum
    # mass (min_slice) serves as the reference for all window adjustments.
    total_mass = np.sum(np.sum(weighted_stack.data, axis=2), axis=1)
    min_mass = np.min(total_mass)
    min_slice = np.argmin(total_mass)

    # Status vector (1 means unmarked, 0 means marked); mark the
    # reference slice (min_slice) from the start.
    status = np.ones(ntilts)
    status[min_slice] = 0

    # Weighting profile `r` perpendicular to the tilt axis: a non-linear
    # cosine taper equal to 1 at the center and 0 at the edges.
    r = np.arange(ny)
    r = 2 / (ny - 1) * r - 1
    r = np.cos(np.pi * r**2) / 2 + 0.5

    # Per-slice window-severity adjustment factors
    adjustments = np.zeros(ntilts)

    # Coarse adjustment loop
    # The applied window is made increasingly restrictive in 10 pixel
    # increments. Whenever the windowed mass of a projection drops below
    # min_mass, that projection is marked and the window restriction is
    # not carried any further for that projection.
    power = 10  # initialize power
    for power in np.linspace(10, niterations, niterations // 10):
        # Compute the power-weighted profile for the current iteration
        r_power = r ** (power * delta)
        window = r_power[:, np.newaxis]  # Broadcasting across all columns
        # Compute the weighted sum for all slices at once using vectorization
        weighted_mass = np.sum(
            weighted_stack.data * window[np.newaxis, :, :],
            axis=(1, 2),
        )
        # Update the status and adjustments for slices with weighted sums
        # below min_mass
        update_mask = (status != 0) & (weighted_mass < min_mass)
        status[update_mask] = 0
        adjustments[update_mask] = power - 10
        # Break early if all slices are marked
        if not np.any(status):
            break

    # Set window for any unmarked slices to the most restrictive used
    # in the rest of the slices
    adjustments[np.where(status != 0)] = power - 10

    # Fine adjustment loop
    # Recompute the window severity starting from the coarse value and
    # tighten it in 1 pixel increments.
    status = np.ones(ntilts)
    status[min_slice] = 0
    for j in range(ntilts):
        if j != min_slice:
            for power in np.linspace(1, 10, 10):
                # Apply fine adjustments to the weight profile and
                # update the weight grid
                r_power = r ** ((power + adjustments[j]) * delta)
                window = r_power[:, np.newaxis]
                if np.sum(weighted_stack.data[j, :, :] * window) < min_mass:
                    adjustments[j] = (power - 1) + adjustments[j]
                    status[j] = 0
                    break
    # Restrict the window of any unmarked projections
    adjustments[status != 0] += 10

    # Apply the final window to the entire stack
    for i in range(ntilts):
        weighted_stack.data[i, :, :] *= (r ** (adjustments[i] * delta))[:, np.newaxis]
    return weighted_stack
def calc_est_angles(num_points: int) -> np.ndarray:
"""
Caculate angles used for equally sloped tomography (EST).
See:
J. Miao, F. Forster, and O. Levi. Equally sloped tomography with
oversampling reconstruction. Phys. Rev. B, 72 (2005) 052103.
https://doi.org/10.1103/PhysRevB.72.052103
Parameters
----------
num_points
Number of points in scan.
Returns
-------
angles : :py:class:`~numpy.ndarray`
Angles in degrees for equally sloped tomography.
Group
-----
utilities
"""
if np.mod(num_points, 2) != 0:
msg = "N must be an even number"
raise ValueError(msg)
angles = np.zeros(2 * num_points)
n = np.arange(num_points / 2 + 1, num_points + 1, dtype="int")
theta1 = -np.arctan((num_points + 2 - 2 * n) / num_points)
theta1 = np.pi / 2 - theta1
n = np.arange(1, num_points + 1, dtype="int")
theta2 = np.arctan((num_points + 2 - 2 * n) / num_points)
n = np.arange(1, num_points / 2 + 1, dtype="int")
theta3 = -np.pi / 2 + np.arctan((num_points + 2 - 2 * n) / num_points)
angles = np.concatenate([theta1, theta2, theta3], axis=0)
angles = angles * 180 / np.pi
angles.sort()
return angles
def calc_golden_ratio_angles(tilt_range: int, nangles: int) -> np.ndarray:
    """
    Calculate golden ratio angles for a given tilt range.

    See:

    A. P. Kaestner, B. Munch and P. Trtik, Opt. Eng., 2011, 50, 123201.
    https://doi.org/10.1117/1.3660298

    Parameters
    ----------
    tilt_range
        Tilt range in degrees.
    nangles
        Number of angles to calculate.

    Returns
    -------
    thetas : :py:class:`~numpy.ndarray`
        Angles in degrees for golden ratio sampling over the provided tilt range.

    Group
    -----
    utilities
    """
    # Work in radians; scaling successive indices by the golden ratio and
    # wrapping into the tilt range yields a quasi-uniform sampling that is
    # then re-centered about zero.
    golden_ratio = (1 + np.sqrt(5)) / 2
    range_rad = tilt_range / 180 * np.pi
    indices = np.arange(1, nangles + 1)
    angles_rad = np.mod(indices * range_rad * golden_ratio, range_rad) - range_rad / 2
    return angles_rad * 180 / np.pi
def get_radial_mask(
    mask_shape: tuple[int, int],
    center: Optional[tuple[int, int]] = None,
) -> np.ndarray:
    """
    Calculate a radial mask given a shape and center position.

    Parameters
    ----------
    mask_shape
        Shape (rows, cols) of the resulting mask.
    center
        (x, y) location of mask center, optional. If ``None``, the center of
        the ``mask_shape`` will be used by default.

    Returns
    -------
    mask : :py:class:`~numpy.ndarray`
        Logical array that is True in the masked region and False outside of it.

    Group
    -----
    utilities
    """
    nrows, ncols = mask_shape
    if center is None:
        # center is interpreted as (x, y), so the x coordinate comes from
        # the number of columns and the y coordinate from the number of
        # rows. (Using (rows//2, cols//2) here would swap the axes and
        # collapse the radius to zero for non-square shapes.)
        center = (ncols // 2, nrows // 2)

    # Largest radius that keeps the circle fully inside the mask bounds
    radius = min(
        center[0],
        center[1],
        ncols - center[0],
        nrows - center[1],
    )
    yy, xx = np.ogrid[0:nrows, 0:ncols]
    dist = np.sqrt((xx - center[0]) ** 2 + (yy - center[1]) ** 2)
    return dist < radius
def filter_stack(
    stack: TomoStack,
    filter_name: Literal[
        "ram-lak",
        "shepp-logan",
        "hanning",
        "hann",
        "cosine",
        "cos",
    ] = "shepp-logan",
    cutoff: float = 0.5,
) -> TomoStack:
    """
    Apply a Fourier filter to a sinogram or series of sinograms.

    Parameters
    ----------
    stack
        TomoStack with projection data
    filter_name
        Type of filter to apply.
    cutoff
        Factor of sampling rate to use as the cutoff. Default is 0.5 which
        corresponds to the Nyquist frequency.

    Returns
    -------
    result : TomoStack
        Filtered version of the input TomoStack.

    Group
    -----
    utilities
    """
    _, ny = stack.data.shape[0:2]

    # Filter length: next power of two of at least twice the projection
    # width (minimum 64) to reduce wrap-around artifacts in the FFT.
    filter_length = max(64, 2 ** int(np.ceil(np.log2(2 * ny))))
    freq_indices = np.arange(filter_length // 2 + 1)

    # Base ramp response for the positive frequencies
    ffilter = np.linspace(
        cutoff / filter_length,
        1 - cutoff / filter_length,
        len(freq_indices),
    )
    omega = 2 * np.pi * freq_indices / filter_length

    # Frequency-domain windows applied on top of the ramp; the DC term is
    # always left untouched. "ram-lak" is the plain ramp (unit window).
    windows = {
        "ram-lak": lambda w: 1.0,
        "shepp-logan": lambda w: np.sinc(w / (2 * np.pi)),
        "hanning": lambda w: (1 + np.cos(w)) / 2,
        "hann": lambda w: (1 + np.cos(w)) / 2,
        "cosine": lambda w: np.cos(w / 2),
        "cos": lambda w: np.cos(w / 2),
    }
    if filter_name not in windows:
        msg = (
            f'Invalid filter type "{filter_name}". Must be one of '
            f"{_fmt(_get_lit(filter_stack, 'filter_name'))}."
        )
        raise ValueError(msg)
    ffilter[1:] = ffilter[1:] * windows[filter_name](omega[1:])

    # Mirror to a full-length, symmetric real filter
    ffilter = np.concatenate((ffilter, ffilter[-2:0:-1]))
    nfilter = ffilter.shape[0]
    pad_length = int((nfilter - ny) / 2)

    # Pad along the detector axis, filter in frequency space, then crop
    # back to the original width
    ndims = len(stack.data.shape)
    if ndims == 2:  # noqa: PLR2004
        padded = np.pad(stack.data, [[0, 0], [pad_length, pad_length]])
        proj_fft = np.fft.fft(padded, axis=1)
        filtered = np.fft.ifft(proj_fft * ffilter, axis=1).real
        filtered = filtered[:, pad_length:-pad_length]
    elif ndims == 3:  # noqa: PLR2004
        padded = np.pad(stack.data, [[0, 0], [pad_length, pad_length], [0, 0]])
        proj_fft = np.fft.fft(padded, axis=1)
        filtered = np.fft.ifft(proj_fft * ffilter[:, np.newaxis], axis=1).real
        filtered = filtered[:, pad_length:-pad_length, :]
    else:
        msg = "Method can only be applied to 2 or 3-dimensional stacks"
        raise ValueError(msg)

    result = stack.deepcopy()
    result.data = filtered
    return result