from ect import ECT
from .embed_complex import EmbeddedGraph, EmbeddedCW
from .directions import Directions
from .results import ECTResult
from typing import Optional, Union
import numpy as np
from numba import njit
class DECT(ECT):
    """
    A class to calculate the Differentiable Euler Characteristic Transform (DECT).

    The DECT smooths the ECT by replacing the hard indicator ``1[h <= t]`` for
    "simplex height ``h`` lies at or below threshold ``t``" with
    ``sigmoid(scale * (t - h))``. The transform is then differentiable in the
    heights and thresholds, and it approaches the step-function (sublevel-set)
    Euler characteristic curve as ``scale`` grows. This sublevel-set convention
    is assumed to match the parent ``ECT``; confirm against ``ECT``.
    """

    def __init__(
        self,
        directions: Optional[Directions] = None,
        num_dirs: Optional[int] = None,
        num_thresh: Optional[int] = None,
        bound_radius: Optional[float] = None,
        thresholds: Optional[np.ndarray] = None,
        dtype=np.float32,
        scale: float = 10.0,
    ):
        """Initialize DECT calculator.

        Every argument except ``scale`` is forwarded positionally, in the same
        order, to ``ECT.__init__``.

        Args:
            scale: Default sigmoid steepness used by ``calculate`` when no
                per-call ``scale`` is given. Larger values give a sharper
                transition and a closer match to the discrete transform.
        """
        super().__init__(
            directions, num_dirs, num_thresh, bound_radius, thresholds, dtype
        )
        self.scale = scale

    @staticmethod
    @njit(fastmath=True)
    def _compute_directional_transform(
        simplex_projections_list, thresholds, dtype=np.float32, scale=10.0
    ):
        """Compute DECT using sigmoid for smooth transitions.

        For every direction ``d`` and threshold index ``j`` this evaluates::

            sum_k (-1)**k * sum_s sigmoid(scale * (thresholds[j] - H_k[s, d]))

        where ``H_k`` is ``simplex_projections_list[k]``. Vertices (k = 0)
        count positively, edges (k = 1) negatively, and so on, as in
        V - E + F. As ``scale`` grows, each sigmoid tends to the step
        ``1[H_k[s, d] <= thresholds[j]]``.

        Args:
            simplex_projections_list: Sequence of 2-D height arrays. Column
                ``d`` of every entry holds heights along direction ``d``. The
                entries are assumed to be ordered by simplex dimension
                (vertices, edges, faces, ...), with one row per simplex.
                TODO confirm against ``_compute_simplex_projections``.
            thresholds: 1-D array of threshold values.
            dtype: dtype of the returned matrix.
            scale: Sigmoid steepness.

        Returns:
            Array of shape ``(num_dir, num_thresh)`` with dtype ``dtype``.
        """
        num_dir = simplex_projections_list[0].shape[1]
        num_thresh = len(thresholds)
        output = np.zeros((num_dir, num_thresh), dtype=dtype)

        for k, simplex_heights in enumerate(simplex_projections_list):
            # Euler characteristic alternates sign with dimension: +V - E + F ...
            sign = 1.0 if k % 2 == 0 else -1.0
            for d in range(num_dir):
                heights = simplex_heights[:, d]
                for j in range(num_thresh):
                    # Positive when the simplex lies below the threshold, so the
                    # sigmoid is a soft sublevel-set indicator.
                    x = scale * (thresholds[j] - heights)
                    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)). Unlike
                    # 1 / (1 + exp(-x)), the tanh form never overflows to inf,
                    # which matters because fastmath lets LLVM assume no infs.
                    soft_count = np.sum(0.5 * (1.0 + np.tanh(0.5 * x)))
                    output[d, j] += sign * soft_count

        return output

    def calculate(
        self,
        graph: Union[EmbeddedGraph, EmbeddedCW],
        scale: Optional[float] = None,
        theta: Optional[float] = None,
        override_bound_radius: Optional[float] = None,
    ) -> ECTResult:
        """Calculate Differentiable Euler Characteristic Transform (DECT)

        Args:
            graph: Embedded graph or CW complex to transform. Its ``dim`` is
                used to set up directions.
            scale: Sigmoid steepness for this call. Falls back to
                ``self.scale`` when None.
            theta: When given, the transform is computed for the single
                direction ``Directions.from_angles([theta])`` instead of
                ``self.directions``.
            override_bound_radius: Forwarded to ``_ensure_thresholds``.

        Returns:
            ``ECTResult`` built from the ``(num_dir, num_thresh)`` matrix, the
            directions used, and ``self.thresholds``.
        """
        self._ensure_directions(graph.dim, theta)
        self._ensure_thresholds(graph, override_bound_radius)
        directions = (
            self.directions if theta is None else Directions.from_angles([theta])
        )
        simplex_projections = self._compute_simplex_projections(graph, directions)
        scale = self.scale if scale is None else scale
        ect_matrix = self._compute_directional_transform(
            simplex_projections, self.thresholds, self.dtype, scale
        )
        return ECTResult(ect_matrix, directions, self.thresholds)