t3toolbox.weighted_tucker_tensor_train.wt3_probe
- t3toolbox.weighted_tucker_tensor_train.wt3_probe(ww: t3toolbox.backend.common.typ.Sequence[t3toolbox.backend.common.NDArray], x: WeightedTuckerTensorTrain, use_jax: bool = False) -> t3toolbox.backend.common.typ.Sequence[t3toolbox.backend.common.NDArray]
Probe a WeightedTuckerTensorTrain.
Examples
>>> import numpy as np
>>> import t3toolbox.tucker_tensor_train as t3
>>> import t3toolbox.weighted_tucker_tensor_train as wt3
>>> import t3toolbox.backend.probing as probing
>>> randn = np.random.randn
>>> x0 = t3.t3_corewise_randn((16,17,18), (5,6,7), (2,3,3,1), stack_shape=(4,))
>>> tucker_vectors = tuple([randn(4, 5), randn(4, 6), randn(4, 7)])
>>> tt_vectors = tuple([randn(4, 2), randn(4, 3), randn(4, 3), randn(4, 1)])
>>> weights = wt3.EdgeVectors(tucker_vectors, tt_vectors)
>>> x = wt3.WeightedTuckerTensorTrain(x0, weights)
>>> ww = (np.random.randn(2,3, 16), np.random.randn(2,3, 17), np.random.randn(2,3, 18))
>>> zz = wt3.wt3_probe(ww, x)
>>> x_dense = x.to_dense()
>>> zz2 = probing.probe_dense(ww, x_dense)
>>> print([np.linalg.norm(z - z2) for z, z2 in zip(zz, zz2)])
[8.629565831373193e-12, 3.713823926468943e-12, 1.2786236870880314e-11]