t3toolbox.weighted_tucker_tensor_train.wt3_apply
- t3toolbox.weighted_tucker_tensor_train.wt3_apply(x: WeightedTuckerTensorTrain, vecs: t3toolbox.backend.common.typ.Sequence[t3toolbox.backend.common.NDArray], use_jax: bool = False) -> t3toolbox.backend.common.NDArray
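Contract a weighted Tucker tensor train with one vector along each of its tensor indices: `vecs[i]` is contracted against tensor index `i`. Stacked inputs are supported. As the example below shows, the stack axes of `x` and the stack axes of the vectors are both kept in the result. In the example, `x` has stack shape `(4,)` and the vectors have stack shape `(2, 3)`, so the result has shape `(4, 2, 3)`.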
Examples
>>> import numpy as np
>>> import t3toolbox.tucker_tensor_train as t3
>>> import t3toolbox.weighted_tucker_tensor_train as wt3
>>> randn = np.random.randn
>>> x0 = t3.t3_corewise_randn((6,7,8), (5,6,7), (2,3,3,1), stack_shape=(4,))
>>> tucker_vectors = tuple([randn(4, 5), randn(4, 6), randn(4, 7)])
>>> tt_vectors = tuple([randn(4, 2), randn(4, 3), randn(4, 3), randn(4, 1)])
>>> weights = wt3.EdgeVectors(tucker_vectors, tt_vectors)
>>> x = wt3.WeightedTuckerTensorTrain(x0, weights)
>>> vecs = [np.random.randn(2,3, 6), np.random.randn(2,3, 7), np.random.randn(2,3, 8)]
>>> result = wt3.wt3_apply(x, vecs)
>>> result2 = np.einsum('uijk,xyi,xyj,xyk->uxy', x.to_dense(), vecs[0], vecs[1], vecs[2])
>>> print(np.linalg.norm(result - result2))
6.89832231894826e-13