Source code for ect.dect

from ect import ECT
from .embed_complex import EmbeddedGraph, EmbeddedCW
from .directions import Directions
from .results import ECTResult
from typing import Optional, Union
import numpy as np
from numba import njit


class DECT(ECT):
    """Differentiable Euler Characteristic Transform (DECT).

    Replaces the hard sublevel-set count of the ECT with a sigmoid so the
    transform is a smooth function of the simplex heights. For every
    direction ``d`` and threshold ``t`` it computes::

        DECT[d, t] = sum_k (-1)**k * sum_{s in K_k} sigmoid(scale * (t - h_d(s)))

    where ``K_k`` is the k-th array returned by
    ``_compute_simplex_projections`` (presumably vertices, edges, faces, ...
    in that order -- confirm against ``ECT``) and ``h_d(s)`` is the projected
    height of simplex ``s`` in direction ``d``. As ``scale`` grows each
    sigmoid tends to the indicator ``h_d(s) < t``, so the result approaches
    the alternating simplex count below the threshold.
    """

    def __init__(
        self,
        directions: Optional[Directions] = None,
        num_dirs: Optional[int] = None,
        num_thresh: Optional[int] = None,
        bound_radius: Optional[float] = None,
        thresholds: Optional[np.ndarray] = None,
        dtype=np.float32,
        scale: float = 10.0,
    ):
        """Initialize the DECT calculator.

        Args:
            directions, num_dirs, num_thresh, bound_radius, thresholds, dtype:
                Forwarded positionally, unchanged, to ``ECT.__init__``.
                ``dtype`` defaults to ``np.float32`` here because the
                sigmoid-smoothed sums are fractional, not integer counts.
            scale: Default sigmoid steepness used by :meth:`calculate` when
                no per-call ``scale`` is supplied. Larger values give a
                sharper approximation of the step function.
        """
        super().__init__(
            directions, num_dirs, num_thresh, bound_radius, thresholds, dtype
        )
        self.scale = scale

    @staticmethod
    @njit(fastmath=True)
    def _compute_directional_transform(
        simplex_projections_list, thresholds, dtype=np.float32, scale=10.0
    ):
        """Compute the DECT matrix using sigmoid-smoothed simplex counts.

        Args:
            simplex_projections_list: Sequence of 2-D arrays; entry ``k`` has
                one row per simplex and one column per direction (the code
                reads ``[:, d]`` and ``.shape[1]``). Entries are signed by
                index parity, so entry ``k`` is assumed to hold the
                k-dimensional simplices -- TODO confirm with the caller.
            thresholds: 1-D array of threshold values.
            dtype: dtype of the returned array.
            scale: Sigmoid steepness.

        Returns:
            Array of shape ``(num_directions, num_thresholds)`` whose entry
            ``[d, t]`` is ``sum_k (-1)**k * sum sigmoid(scale * (t - h))``.
        """
        num_dir = simplex_projections_list[0].shape[1]
        num_thresh = len(thresholds)
        output = np.zeros((num_dir, num_thresh), dtype=dtype)

        for k, simplex_heights in enumerate(simplex_projections_list):
            # Euler-characteristic sign: + for even index, - for odd. This is
            # loop-invariant, so it is computed once per simplex array.
            sign = 1.0 if k % 2 == 0 else -1.0
            for d in range(num_dir):
                heights = simplex_heights[:, d]
                for t in range(num_thresh):
                    # sigmoid(scale * (t - h)) is ~1 for simplices below the
                    # threshold, i.e. a soft sublevel-set indicator.
                    x = scale * (thresholds[t] - heights)
                    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)). The tanh form
                    # never produces inf (unlike exp(-x) for large |x|),
                    # which matters because fastmath assumes no infinities.
                    soft_count = np.sum(0.5 * (1.0 + np.tanh(0.5 * x)))
                    output[d, t] += sign * soft_count

        return output

    def calculate(
        self,
        graph: Union[EmbeddedGraph, EmbeddedCW],
        scale: Optional[float] = None,
        theta: Optional[float] = None,
        override_bound_radius: Optional[float] = None,
    ) -> ECTResult:
        """Calculate the Differentiable Euler Characteristic Transform (DECT).

        Args:
            graph: Embedded graph or CW complex to transform.
            scale: Sigmoid steepness for this call; falls back to
                ``self.scale`` when ``None``.
            theta: If given, compute along the single direction built by
                ``Directions.from_angles([theta])`` instead of
                ``self.directions``.
            override_bound_radius: Forwarded to ``_ensure_thresholds``.

        Returns:
            ``ECTResult`` wrapping the ``(num_directions, num_thresholds)``
            DECT matrix, the directions used, and ``self.thresholds``.
        """
        self._ensure_directions(graph.dim, theta)
        self._ensure_thresholds(graph, override_bound_radius)

        directions = (
            self.directions if theta is None else Directions.from_angles([theta])
        )
        simplex_projections = self._compute_simplex_projections(graph, directions)

        scale = self.scale if scale is None else scale
        ect_matrix = self._compute_directional_transform(
            simplex_projections, self.thresholds, self.dtype, scale
        )

        return ECTResult(ect_matrix, directions, self.thresholds)