t3toolbox.weighted_tucker_tensor_train.wt3_probe

t3toolbox.weighted_tucker_tensor_train.wt3_probe(ww: t3toolbox.backend.common.typ.Sequence[t3toolbox.backend.common.NDArray], x: WeightedTuckerTensorTrain, use_jax: bool = False) -> t3toolbox.backend.common.typ.Sequence[t3toolbox.backend.common.NDArray]

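Probe a WeightedTuckerTensorTrain.

Applies the probing vectors ``ww`` to ``x`` and returns one result per tensor mode. The result matches ``probing.probe_dense(ww, x.to_dense())`` up to floating-point round-off, as the example below checks.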

Examples

>>> import numpy as np
>>> import t3toolbox.tucker_tensor_train as t3
>>> import t3toolbox.weighted_tucker_tensor_train as wt3
>>> import t3toolbox.backend.probing as probing
>>> randn = np.random.randn
>>> x0 = t3.t3_corewise_randn((16,17,18), (5,6,7), (2,3,3,1), stack_shape=(4,))
>>> tucker_vectors = tuple([randn(4, 5), randn(4, 6), randn(4, 7)])
>>> tt_vectors = tuple([randn(4, 2), randn(4, 3), randn(4, 3), randn(4, 1)])
>>> weights = wt3.EdgeVectors(tucker_vectors, tt_vectors)
>>> x = wt3.WeightedTuckerTensorTrain(x0, weights)
>>> ww = (np.random.randn(2,3, 16), np.random.randn(2,3, 17), np.random.randn(2,3, 18))
>>> zz = wt3.wt3_probe(ww, x)
>>> x_dense = x.to_dense()
>>> zz2 = probing.probe_dense(ww, x_dense)
>>> print([np.linalg.norm(z - z2) for z, z2 in zip(zz, zz2)])
[8.629565831373193e-12, 3.713823926468943e-12, 1.2786236870880314e-11]