from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import cv2
import numpy as np
from tqdm import tqdm
from .adapters import read_alignment
from .adapters.segmentation import SegmentationAdapterRegistry
from .atlas_map import transform_to_atlas_space
from .utils import (
convert_to_intensity,
discover_image_files,
resize_mask_nearest,
)
from ..io.loaders import number_sections
from ..io.atlas_loader import resolve_atlas
from .reorientation import reorient_volume
@dataclass(frozen=True)
class VolumeConfig:
"""Immutable configuration for section-to-volume projection.
Groups the per-run processing parameters so they can be passed as a
single object to ``_process_one_section`` instead of nine positional
arguments.
"""
segmentation_adapter: object
segmentation_mode: bool
colour_arr: Optional[np.ndarray]
intensity_channel: str
min_intensity: Optional[int]
max_intensity: Optional[int]
scale: float
value_mode: str
@dataclass(frozen=True)
class InterpolationConfig:
"""Immutable configuration for volume interpolation and finalisation.
Groups the post-accumulation parameters so they can be passed as a
single object to ``_finalize_volumes``.
"""
do_interpolation: bool
missing_fill: float
use_atlas_mask: bool
atlas_volume: Optional[np.ndarray]
k: int
batch_size: int
def derive_shape_from_atlas(
*,
atlas_shape: Tuple[int, int, int],
scale: float,
) -> Tuple[int, int, int]:
"""Derive an output shape from atlas shape + scale."""
if scale <= 0:
raise ValueError("scale must be > 0")
return tuple(max(1, int(round(int(s) * float(scale)))) for s in atlas_shape)
def _knn_batch_query(tree, fit_vals, query_pts, k, batch_size, mode):
"""Query *tree* in batches and return interpolated values."""
out_vals = np.empty((query_pts.shape[0],), dtype=np.float32)
for start in tqdm(range(0, query_pts.shape[0], batch_size), desc="filling volume"):
end = min(start + batch_size, query_pts.shape[0])
_, ind = tree.query(query_pts[start:end], k=k)
if k == 1:
out_vals[start:end] = fit_vals[ind]
else:
neigh_vals = fit_vals[ind]
out_vals[start:end] = (
neigh_vals.mean(axis=1) if mode == "mean" else neigh_vals.max(axis=1)
)
return out_vals
def _knn_interpolate_generic(
*,
gv: np.ndarray,
fv: np.ndarray,
atlas_mask: Optional[np.ndarray],
k: int,
batch_size: int,
mode: str = "mean",
) -> np.ndarray:
try:
from scipy.spatial import cKDTree # type: ignore
except Exception as exc: # pragma: no cover
raise ImportError("SciPy is required for do_interpolation=True") from exc
if k < 1:
raise ValueError("k must be >= 1")
fit_mask = fv != 0
if atlas_mask is not None:
target_mask = atlas_mask
fit_mask &= atlas_mask
else:
target_mask = np.ones_like(fv, dtype=bool)
if not (np.any(target_mask) and np.any(fit_mask)):
return gv
fit_pts = np.column_stack(np.nonzero(fit_mask)).astype(np.float32, copy=False)
fit_vals = gv[fit_mask].astype(np.float32, copy=False)
tree = cKDTree(fit_pts)
query_pts = np.column_stack(np.nonzero(target_mask)).astype(np.float32, copy=False)
out_vals = _knn_batch_query(tree, fit_vals, query_pts, k, batch_size, mode)
if atlas_mask is not None:
out = np.zeros_like(gv)
out[target_mask] = out_vals
return out
gv[target_mask] = out_vals
return gv
def _read_section_signal(
seg_path: str,
vol_cfg: VolumeConfig,
):
"""Load an image and return (seg_values, mask, seg_height, seg_width) or None."""
if vol_cfg.segmentation_mode:
seg = vol_cfg.segmentation_adapter.load(seg_path)
else:
seg = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED)
if seg is None:
return None
if not vol_cfg.segmentation_mode:
# Intensity mode
seg_values = convert_to_intensity(seg, vol_cfg.intensity_channel)
if vol_cfg.min_intensity is not None:
seg_values[seg_values < vol_cfg.min_intensity] = 0
if vol_cfg.max_intensity is not None:
seg_values[seg_values > vol_cfg.max_intensity] = 0
mask = (seg_values != 0).astype(np.float32, copy=False)
else:
# Segmentation mode via adapter (supports binary/cellpose/custom)
pixel_id = vol_cfg.colour_arr.tolist() if vol_cfg.colour_arr is not None else None
mask = vol_cfg.segmentation_adapter.create_binary_mask(seg, pixel_id=pixel_id).astype(
np.float32, copy=False
)
seg_values = mask
seg_height, seg_width = seg.shape[:2]
return seg_values, mask, seg_height, seg_width
def _sample_and_deform_plane(
slice_info,
values_reg: np.ndarray,
damage_reg: Optional[np.ndarray],
vol_cfg: VolumeConfig,
):
"""Construct a sampling grid, optionally deform, and remap values.
Returns (sampled_2d, vals_flat, damage_vals, flat_x, flat_y, plane_h, plane_w).
"""
reg_height, reg_width = slice_info.height, slice_info.width
scale = vol_cfg.scale
anch = slice_info.anchoring
u = np.asarray(anch[3:6], dtype=np.float32)
v = np.asarray(anch[6:9], dtype=np.float32)
plane_w = max(1, int(round(float(np.linalg.norm(u)) * float(scale))))
plane_h = max(1, int(round(float(np.linalg.norm(v)) * float(scale))))
yy, xx = np.indices((plane_h, plane_w), dtype=np.float32)
reg_x = (xx + 0.5) * (float(reg_width) / float(plane_w))
reg_y = (yy + 0.5) * (float(reg_height) / float(plane_h))
flat_x = reg_x.reshape(-1)
flat_y = reg_y.reshape(-1)
if slice_info.forward_deformation is not None:
new_x, new_y = slice_info.forward_deformation(flat_x, flat_y)
map_x = new_x.reshape((plane_h, plane_w)).astype(np.float32, copy=False)
map_y = new_y.reshape((plane_h, plane_w)).astype(np.float32, copy=False)
else:
map_x = reg_x.astype(np.float32, copy=False)
map_y = reg_y.astype(np.float32, copy=False)
sampled_2d = cv2.remap(
values_reg,
map_x,
map_y,
interpolation=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0,
)
vals_flat = sampled_2d.reshape(-1).astype(np.float32, copy=False)
damage_vals = None
if damage_reg is not None:
sampled_damage = cv2.remap(
damage_reg,
map_x,
map_y,
interpolation=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0,
)
damage_vals = sampled_damage.reshape(-1)
return sampled_2d, vals_flat, damage_vals, flat_x, flat_y, plane_h, plane_w
def _accumulate_object_counts(
sampled_2d: np.ndarray,
inb: np.ndarray,
x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
seg_nr: int,
out_shape: Tuple[int, int, int],
ov_flat: np.ndarray,
):
"""Count unique 2-D connected components per voxel, accumulating into *ov_flat*."""
sampled_u8 = (sampled_2d != 0).astype(np.uint8)
_n_labels, labels = cv2.connectedComponents(sampled_u8, connectivity=8)
flat_labels = labels.reshape(-1)
obj = flat_labels[inb]
pos = obj != 0
if not np.any(pos):
return
x_pos = x[pos].astype(np.int64, copy=False)
y_pos = y[pos].astype(np.int64, copy=False)
z_pos = z[pos].astype(np.int64, copy=False)
voxel_lin = np.ravel_multi_index(
(x_pos, y_pos, z_pos),
dims=out_shape,
mode="raise",
order="C",
).astype(np.int64, copy=False)
obj_u32 = obj[pos].astype(np.uint64, copy=False)
sec_u64 = np.uint64(seg_nr)
obj_key = (sec_u64 << np.uint64(32)) | obj_u32
pairs = np.empty((voxel_lin.shape[0],), dtype=[("v", "u8"), ("o", "u8")])
pairs["v"] = voxel_lin.astype(np.uint64, copy=False)
pairs["o"] = obj_key
uniq_pairs = np.unique(pairs)
vox_u = uniq_pairs["v"].astype(np.int64, copy=False)
vox_ids, per_vox = np.unique(vox_u, return_counts=True)
np.add.at(ov_flat, vox_ids, per_vox.astype(np.uint32, copy=False))
def _compute_value_volume(gv, fv, ov_flat, value_mode, missing_fill):
"""Derive the value volume from accumulated sums/counts."""
if value_mode == "mean":
out = np.zeros_like(gv, dtype=np.float32)
covered = fv != 0
out[covered] = gv[covered] / fv[covered].astype(np.float32)
if missing_fill is not None and np.any(~covered):
out[~covered] = float(missing_fill)
return out
if value_mode == "object_count":
return ov_flat.reshape(gv.shape).astype(np.float32, copy=False)
return gv
def _resolve_atlas_mask(use_atlas_mask, atlas_volume, gv):
"""Build the atlas-mask array used during interpolation."""
if use_atlas_mask and atlas_volume is not None and atlas_volume.shape == gv.shape:
return atlas_volume != 0
return None
def _finalize_volumes(
gv: np.ndarray,
fv: np.ndarray,
dv: np.ndarray,
ov_flat: Optional[np.ndarray],
vol_cfg: VolumeConfig,
interp_cfg: InterpolationConfig,
):
"""Convert accumulated sums and optionally interpolate."""
gv = _compute_value_volume(gv, fv, ov_flat, vol_cfg.value_mode, interp_cfg.missing_fill)
if interp_cfg.do_interpolation:
atlas_mask = _resolve_atlas_mask(interp_cfg.use_atlas_mask, interp_cfg.atlas_volume, gv)
gv = _knn_interpolate_generic(
gv=gv,
fv=fv,
atlas_mask=atlas_mask,
k=interp_cfg.k,
batch_size=interp_cfg.batch_size,
mode="mean",
)
if np.any(dv > 0):
dv_float = dv.astype(np.float32)
dv_interp = _knn_interpolate_generic(
gv=dv_float,
fv=fv,
atlas_mask=atlas_mask,
k=interp_cfg.k,
batch_size=interp_cfg.batch_size,
mode="max",
)
dv = (dv_interp > 0).astype(np.uint8)
elif interp_cfg.missing_fill is not None and interp_cfg.missing_fill != 0:
gv[fv == 0] = float(interp_cfg.missing_fill)
return (
gv.astype(np.float32, copy=False),
fv.astype(np.uint32, copy=False),
dv.astype(np.uint8, copy=False),
)
def _process_one_section(
seg_path,
slice_by_nr,
vol_cfg: VolumeConfig,
gv,
fv,
dv,
ov_flat,
):
"""Process a single section path and accumulate into the output volumes."""
out_shape = gv.shape
sx, sy, sz = out_shape
seg_nr = int(number_sections([seg_path])[0])
slice_info = slice_by_nr.get(seg_nr)
if not slice_info or not slice_info.anchoring:
return
loaded = _read_section_signal(
seg_path,
vol_cfg,
)
if loaded is None:
return
seg_values, mask, seg_height, seg_width = loaded
reg_height, reg_width = slice_info.height, slice_info.width
# Prepare damage mask in registration space
damage_reg = (
None
if slice_info.damage_mask is None
else resize_mask_nearest(
slice_info.damage_mask.astype(np.uint8),
reg_width,
reg_height,
).astype(np.uint8)
)
# Resample segmentation values into registration space
src = seg_values if vol_cfg.value_mode == "mean" else mask
values_reg = cv2.resize(
src, (reg_width, reg_height), interpolation=cv2.INTER_NEAREST
)
# Sample, deform, and remap the plane
sampled_2d, vals, damage_vals, flat_x, flat_y, plane_h, plane_w = (
_sample_and_deform_plane(slice_info, values_reg, damage_reg, vol_cfg)
)
# Transform flat grid to atlas-space 3-D coordinates
coords = transform_to_atlas_space(slice_info, flat_y, flat_x)
if vol_cfg.scale != 1.0:
coords = coords * float(vol_cfg.scale)
idx = np.rint(coords).astype(np.int64, copy=False)
x, y, z = idx[:, 0], idx[:, 1], idx[:, 2]
inb = (x >= 0) & (x < sx) & (y >= 0) & (y < sy) & (z >= 0) & (z < sz)
if not np.any(inb):
return
x, y, z = x[inb], y[inb], z[inb]
np.add.at(fv, (x, y, z), 1)
if vol_cfg.value_mode != "object_count":
np.add.at(gv, (x, y, z), vals[inb])
if damage_vals is not None:
dv[x, y, z] |= damage_vals[inb].astype(np.uint8)
if ov_flat is not None:
_accumulate_object_counts(sampled_2d, inb, x, y, z, seg_nr, out_shape, ov_flat)
[docs]
def interpolate_volume(
*,
segmentation_folder: str,
alignment_json: str,
colour,
atlas: object,
scale: float = 1.0,
missing_fill: float = np.nan,
do_interpolation: bool = True,
k: int = 5,
batch_size: int = 200_000,
use_atlas_mask: bool = True,
value_mode: str = "pixel_count",
segmentation_format: str = "binary",
segmentation_mode: bool = True,
intensity_channel: str = "grayscale",
min_intensity: Optional[int] = None,
max_intensity: Optional[int] = None,
return_orientation: str = "asr",
):
"""Project section data into atlas-space volumes.
Parameters
----------
segmentation_folder
Path to the folder containing segmentation images or source images.
alignment_json
Path to the registration JSON passed to
:func:`PyNutil.read_alignment`.
colour
Segmentation color or class identifier to extract. Use ``None`` or
``"auto"`` to defer selection to the segmentation adapter.
atlas
Atlas definition used to determine the target volume shape. This may
be a BrainGlobe atlas object or :class:`~PyNutil.AtlasData`.
scale
Isotropic scaling factor applied to the atlas output shape.
missing_fill
Fill value assigned to voxels with no sampled data when interpolation
is disabled or when uncovered voxels remain after processing.
do_interpolation
If ``True``, fill uncovered voxels using k-nearest-neighbor
interpolation.
k
Number of neighbors to use during interpolation.
batch_size
Number of query voxels processed per interpolation batch.
use_atlas_mask
If ``True``, restrict interpolation to voxels inside the atlas mask.
value_mode
Output volume mode. Supported values are ``"pixel_count"``,
``"mean"``, and ``"object_count"``.
segmentation_format
Name of the segmentation adapter to use when ``segmentation_mode`` is
enabled.
segmentation_mode
If ``True``, treat input files as segmentation outputs. If ``False``,
treat them as source images and derive intensities from
``intensity_channel``.
intensity_channel
Image channel to convert to intensity values when
``segmentation_mode=False``.
min_intensity
Optional lower threshold for intensity-mode inputs.
max_intensity
Optional upper threshold for intensity-mode inputs.
Returns
-------
tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
A tuple ``(interpolated_volume, frequency_volume, damage_volume)``.
The first element stores the requested value volume, the second stores
per-voxel sampling frequency, and the third is a binary damage mask.
Examples
--------
Build atlas-space volumes from segmentation images:
>>> gv, fv, dv = interpolate_volume(
... segmentation_folder="path/to/segmentations/",
... alignment_json="path/to/alignment.json",
... colour=[0, 0, 0],
... atlas=atlas,
... )
"""
if value_mode not in {"pixel_count", "mean", "object_count"}:
raise ValueError(
"value_mode must be one of 'pixel_count', 'mean', or 'object_count'"
)
atlas_data = resolve_atlas(atlas)
atlas_volume = atlas_data.volume
out_base_shape = tuple(int(x) for x in atlas_volume.shape)
out_shape = derive_shape_from_atlas(atlas_shape=out_base_shape, scale=scale)
registration = read_alignment(alignment_json)
slice_by_nr = {s.section_number: s for s in registration.slices}
seg_paths = discover_image_files(segmentation_folder)
# Accept GUI/settings values like "auto" and defer to adapter auto-detection
# by passing pixel_id=None.
if isinstance(colour, str):
colour_str = colour.strip()
if colour_str.lower() == "auto" or colour_str == "":
colour_arr = None
elif colour_str.isdigit():
colour_arr = np.array([int(colour_str)], dtype=np.uint8)
elif "," in colour_str:
colour_arr = np.array(
[int(x.strip()) for x in colour_str.strip("[]").split(",") if x.strip()],
dtype=np.uint8,
)
else:
raise ValueError(
"colour must be None, 'auto', an int-like string, or a list/tuple of ints"
)
else:
colour_arr = np.array(colour, dtype=np.uint8) if colour is not None else None
vol_cfg = VolumeConfig(
segmentation_adapter=(
SegmentationAdapterRegistry.get(segmentation_format)
if segmentation_mode
else None
),
segmentation_mode=segmentation_mode,
colour_arr=colour_arr,
intensity_channel=intensity_channel,
min_intensity=min_intensity,
max_intensity=max_intensity,
scale=scale,
value_mode=value_mode,
)
interp_cfg = InterpolationConfig(
do_interpolation=do_interpolation,
missing_fill=missing_fill,
use_atlas_mask=use_atlas_mask,
atlas_volume=atlas_volume,
k=k,
batch_size=batch_size,
)
gv = np.zeros(out_shape, dtype=np.float32)
fv = np.zeros(out_shape, dtype=np.uint32)
dv = np.zeros(out_shape, dtype=np.uint8)
ov_flat = (
np.zeros((gv.size,), dtype=np.uint32) if value_mode == "object_count" else None
)
for seg_path in seg_paths:
_process_one_section(seg_path, slice_by_nr, vol_cfg, gv, fv, dv, ov_flat)
gv, fv, dv = _finalize_volumes(gv, fv, dv, ov_flat, vol_cfg, interp_cfg)
if return_orientation != "lpi":
atlas_shape = out_base_shape
gv = reorient_volume(gv, atlas_shape, return_orientation)
fv = reorient_volume(fv, atlas_shape, return_orientation)
dv = reorient_volume(dv, atlas_shape, return_orientation)
return gv, fv, dv