t3toolbox.weighted_tucker_tensor_train.wt3_apply

t3toolbox.weighted_tucker_tensor_train.wt3_apply(x: WeightedTuckerTensorTrain, vecs: t3toolbox.backend.common.typ.Sequence[t3toolbox.backend.common.NDArray], use_jax: bool = False) → t3toolbox.backend.common.NDArray

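Contract a weighted Tucker tensor train with vectors in all indices.

Each array vecs[i] has shape B + (N_i,), where N_i is the size of the i-th tensor index and B is a batch shape shared by all entries of vecs. The result has shape S + B, where S is the stack shape of x. In the example below, x has stack shape (4,), the vectors have batch shape (2, 3), and the result has shape (4, 2, 3).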

Examples

>>> import numpy as np
>>> import t3toolbox.tucker_tensor_train as t3
>>> import t3toolbox.weighted_tucker_tensor_train as wt3
>>> randn = np.random.randn
>>> x0 = t3.t3_corewise_randn((6,7,8), (5,6,7), (2,3,3,1), stack_shape=(4,))
>>> tucker_vectors = tuple([randn(4, 5), randn(4, 6), randn(4, 7)])
>>> tt_vectors = tuple([randn(4, 2), randn(4, 3), randn(4, 3), randn(4, 1)])
>>> weights = wt3.EdgeVectors(tucker_vectors, tt_vectors)
>>> x = wt3.WeightedTuckerTensorTrain(x0, weights)
>>> vecs = [randn(2, 3, 6), randn(2, 3, 7), randn(2, 3, 8)]
>>> result = wt3.wt3_apply(x, vecs)
>>> result2 = np.einsum('uijk,xyi,xyj,xyk->uxy', x.to_dense(), vecs[0], vecs[1], vecs[2])
>>> print(np.linalg.norm(result - result2))
6.89832231894826e-13
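To illustrate the contraction pattern, here is a minimal NumPy-only sketch on a plain Tucker tensor train. It is not t3toolbox's implementation. The helper names `tucker_tt_apply_sketch` and `tucker_tt_to_dense_sketch` are hypothetical, the core layouts are assumptions chosen for illustration, boundary TT ranks are assumed to be 1, and the sketch does not model the edge weights, the stack axes, or the batch axes of vecs. Each vector is first projected onto its Tucker basis. The resulting rank-by-rank matrices are then multiplied in a left-to-right sweep, so the dense tensor is formed only for the check.

```python
import numpy as np


def tucker_tt_apply_sketch(bases, cores, vecs):
    """Contract an unweighted, unstacked Tucker tensor train with one vector per index.

    Hypothetical layouts (assumptions, not necessarily t3toolbox's):
      bases[i]: (n_i, N_i)           Tucker basis matrix for index i
      cores[i]: (r_{i-1}, n_i, r_i)  TT core, with boundary ranks r_0 = r_d = 1
      vecs[i]:  (N_i,)               vector contracted into index i
    """
    left = np.ones(1)  # running row vector over the current TT rank, starts at r_0 = 1
    for U, G, v in zip(bases, cores, vecs):
        w = U @ v                          # project v onto the Tucker basis: (n_i,)
        M = np.einsum('anb,n->ab', G, w)   # absorb into the TT core: (r_{i-1}, r_i)
        left = left @ M                    # left-to-right sweep: (r_i,)
    return left[0]                         # r_d = 1, so a single scalar remains


def tucker_tt_to_dense_sketch(bases, cores):
    """Form the full tensor (same hypothetical layouts); used only for checking."""
    T = np.ones(1)  # shape (r_0,)
    for U, G in zip(bases, cores):
        full_core = np.einsum('anb,nN->aNb', G, U)        # (r_{i-1}, N_i, r_i)
        T = np.tensordot(T, full_core, axes=([-1], [0]))  # append index N_i
    return T[..., 0]  # drop the trailing r_d = 1 axis


rng = np.random.default_rng(0)
N, n, r = (6, 7, 8), (5, 6, 7), (1, 3, 3, 1)
bases = [rng.standard_normal((n[i], N[i])) for i in range(3)]
cores = [rng.standard_normal((r[i], n[i], r[i + 1])) for i in range(3)]
vecs = [rng.standard_normal(N[i]) for i in range(3)]

fast = tucker_tt_apply_sketch(bases, cores, vecs)
dense = tucker_tt_to_dense_sketch(bases, cores)
slow = np.einsum('ijk,i,j,k->', dense, *vecs)
print(np.isclose(fast, slow))  # True
```

Under these assumptions, the sweep costs about n_i N_i + r_{i-1} n_i r_i operations per index. The dense check must build and contract all N_1 N_2 ... N_d entries.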