t3toolbox.backend.contractions.Na_Maib_Ni_to_NMb

t3toolbox.backend.contractions.Na_Maib_Ni_to_NMb(Na: t3toolbox.backend.common.NDArray, Maib: t3toolbox.backend.common.NDArray, Ni: t3toolbox.backend.common.NDArray, use_jax: bool = False) -> t3toolbox.backend.common.NDArray

Computes the vectorized einsum a,aib,i->b, vectorized over the leading indices of Na and Ni (N), over the leading indices of Maib (M), or over both.

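N and M are the vectorization indices, each of which may be a group of several indices or empty. Na has shape N + (a,), Maib has shape M + (a, i, b), and Ni has shape N + (i,); the result has shape N + M + (b,).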
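To make the index bookkeeping concrete, here is a minimal plain-NumPy sketch of the same contraction. It is not t3toolbox's implementation: the name na_maib_ni_to_nmb_sketch is hypothetical, and the sketch ignores use_jax. It assumes the shapes described above, flattens each index group into a single batch axis, contracts over a and then over i, and reshapes the result back to N + M + (b,).

```python
import numpy as np


def na_maib_ni_to_nmb_sketch(Na, Maib, Ni):
    """Hypothetical plain-NumPy sketch (not t3toolbox's code).

    Computes out[N..., M..., b] = sum_{a, i} Na[N..., a] * Maib[M..., a, i, b] * Ni[N..., i],
    where N = Na.shape[:-1] (shared with Ni) and M = Maib.shape[:-3].
    """
    N_shape = Na.shape[:-1]
    M_shape = Maib.shape[:-3]
    a_dim, i_dim, b_dim = Maib.shape[-3:]
    if Ni.shape[:-1] != N_shape:
        raise ValueError("Na and Ni must share the same leading (N) shape")
    if Na.shape[-1] != a_dim or Ni.shape[-1] != i_dim:
        raise ValueError("trailing dimensions of Na/Ni must match a/i of Maib")

    # Flatten each index group into one batch axis (size 1 if the group is empty).
    A = Na.reshape(-1, a_dim)                   # (n, a)
    I = Ni.reshape(-1, i_dim)                   # (n, i)
    B = Maib.reshape(-1, a_dim, i_dim, b_dim)   # (m, a, i, b)

    # Contract over a first, then over i.
    T = np.einsum('na,maib->nmib', A, B)        # (n, m, i, b)
    out = np.einsum('nmib,ni->nmb', T, I)       # (n, m, b)

    # Restore the original index groups: N + M + (b,).
    return out.reshape(N_shape + M_shape + (b_dim,))


# Usage: the "vectorize over both N and M" case.
rng = np.random.default_rng(0)
xyz_a = rng.standard_normal((2, 3, 4, 10))
uv_aib = rng.standard_normal((5, 6, 10, 11, 12))
xyz_i = rng.standard_normal((2, 3, 4, 11))
out = na_maib_ni_to_nmb_sketch(xyz_a, uv_aib, xyz_i)
print(out.shape)  # (2, 3, 4, 5, 6, 12)
print(np.allclose(out, np.einsum('xyza,uvaib,xyzi->xyzuvb', xyz_a, uv_aib, xyz_i)))  # True
```

Contracting pairwise rather than passing all three operands to a single einsum call is one reasonable choice for keeping intermediates small; the library may order the contraction differently.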

Examples

Vectorize over both N and M:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> xyz_a = np.random.randn(2,3,4, 10)
>>> uv_aib = np.random.randn(5,6, 10,11,12)
>>> xyz_i = np.random.randn(2,3,4, 11)
>>> NMb = Na_Maib_Ni_to_NMb(xyz_a, uv_aib, xyz_i)
>>> NMb_true = np.einsum('xyza,uvaib,xyzi->xyzuvb', xyz_a, uv_aib, xyz_i)
>>> print(NMb.shape == NMb_true.shape)
True
>>> print(np.linalg.norm(NMb - NMb_true))
3.5869432063566724e-13

Vectorize over N only:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> xyz_a = np.random.randn(2,3,4, 10)
>>> aib = np.random.randn(10,11,12)
>>> xyz_i = np.random.randn(2,3,4, 11)
>>> Nb = Na_Maib_Ni_to_NMb(xyz_a, aib, xyz_i)
>>> Nb_true = np.einsum('xyza,aib,xyzi->xyzb', xyz_a, aib, xyz_i)
>>> print(Nb.shape == Nb_true.shape)
True
>>> print(np.linalg.norm(Nb - Nb_true))
7.459556385862986e-14

Vectorize over M only:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> a = np.random.randn(10)
>>> uv_aib = np.random.randn(5,6, 10,11,12)
>>> i = np.random.randn(11)
>>> Mb = Na_Maib_Ni_to_NMb(a, uv_aib, i)
>>> Mb_true = np.einsum('a,uvaib,i->uvb', a, uv_aib, i)
>>> print(Mb.shape == Mb_true.shape)
True
>>> print(np.linalg.norm(Mb - Mb_true))
1.254699383909023e-14

No vectorization:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> a = np.random.randn(10)
>>> aib = np.random.randn(10,11,12)
>>> i = np.random.randn(11)
>>> b = Na_Maib_Ni_to_NMb(a, aib, i)
>>> b_true = np.einsum('a,aib,i->b', a, aib, i)
>>> print(b.shape == b_true.shape)
True
>>> print(np.linalg.norm(b - b_true))
6.108244889215317e-15