t3toolbox.weighted_tucker_tensor_train.wt3_entries

t3toolbox.weighted_tucker_tensor_train.wt3_entries(x: WeightedTuckerTensorTrain, index: t3toolbox.backend.common.NDArray, use_jax: bool = False) -> t3toolbox.backend.common.NDArray

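Compute one entry (or multiple entries) of a weighted Tucker tensor train.

`index` has one row per tensor mode, and each column selects one entry. For example, `index = [[9,0], [4,0], [7,0]]` selects the entries (9,4,7) and (0,0,0). In the example below, the result has shape (4, 2): the stack shape (4,) followed by the number of requested entries.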

Examples

>>> import numpy as np
>>> import t3toolbox.tucker_tensor_train as t3
>>> import t3toolbox.weighted_tucker_tensor_train as wt3
>>> randn = np.random.randn
>>> x0 = t3.t3_corewise_randn((16,17,18), (5,6,7), (2,3,3,1), stack_shape=(4,))
>>> tucker_vectors = tuple([randn(4, 5), randn(4, 6), randn(4, 7)])
>>> tt_vectors = tuple([randn(4, 2), randn(4, 3), randn(4, 3), randn(4, 1)])
>>> weights = wt3.EdgeVectors(tucker_vectors, tt_vectors)
>>> x = wt3.WeightedTuckerTensorTrain(x0, weights)
>>> index = [[9,0], [4,0], [7,0]] # get entries (9,4,7) and (0,0,0)
>>> entries = wt3.wt3_entries(x, index)
>>> x_dense = x.to_dense()
>>> entries2 = np.moveaxis(np.array([x_dense[:, 9,4,7], x_dense[:, 0,0,0]]), 0,1)
>>> print(np.linalg.norm(entries - entries2))
2.8718552890331766e-14