t3toolbox.backend.common.ragged_scan
- t3toolbox.backend.common.ragged_scan(f: Callable[[CarryType, Sequence[NDArray]], Tuple[CarryType, Sequence[NDArray]]], init: CarryType, xs: Sequence[Sequence[NDArray] | NDArray]) → Tuple[CarryType, Tuple[Tuple[NDArray, Ellipsis], Ellipsis]]
Similar to jax.lax.scan, but for ragged arrays, i.e. arrays whose sizes may differ from one scan step to the next. See https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html
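
The signature above suggests the following scan pattern, sketched here in plain NumPy. This is not the t3toolbox implementation: the name `ragged_scan_sketch` is hypothetical, and the data layout is an assumption read off the type hints, namely that `xs[k][i]` is the k-th input at step `i` and that the outputs come back in the same layout, `ys[k][i]`. Because the per-step arrays may have different shapes, the outputs are kept as tuples of arrays rather than stacked as jax.lax.scan would do.

```python
import numpy as np


def ragged_scan_sketch(f, init, xs):
    """Hypothetical stand-in for ragged_scan, NOT the library's code.

    Assumed layout: xs[k][i] is the k-th input at scan step i, and every
    xs[k] has the same number of steps.
    """
    n_steps = len(xs[0]) if len(xs) > 0 else 0
    carry = init
    ys_per_step = []
    for i in range(n_steps):
        # Gather the i-th slice of every input; shapes may vary with i.
        x_i = tuple(x[i] for x in xs)
        carry, y_i = f(carry, x_i)
        ys_per_step.append(tuple(y_i))
    # Transpose from step-major (ys_per_step[i][k]) to ys[k][i].
    ys = tuple(zip(*ys_per_step))
    return carry, ys


def step(carry, x):
    # One input per step: add its sum to the running total and emit the
    # array shifted by the total accumulated before this step.
    (a,) = x
    return carry + a.sum(), (a + carry,)


# One input slot holding three arrays of different lengths.
xs = ([np.ones(2), np.ones(3), np.ones(1)],)
final_carry, ys = ragged_scan_sketch(step, 0.0, xs)
```

In this sketch `final_carry` is the sum of all entries, and `ys[0]` holds one output array per step, each with the same shape as the corresponding input.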