t3toolbox.backend.contractions.Na_Maib_Ni_to_NMb
- t3toolbox.backend.contractions.Na_Maib_Ni_to_NMb(Na: t3toolbox.backend.common.NDArray, Maib: t3toolbox.backend.common.NDArray, Ni: t3toolbox.backend.common.NDArray, use_jax: bool = False) → t3toolbox.backend.common.NDArray
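Computes the vectorized einsum a,aib,i->b, with vectorization over the leading indices N shared by the a and i inputs, over the leading indices M of the aib input, or over both.
N and M are the vectorization indices; each may stand for a group of several indices. In the examples below, the output indices are ordered N, then M, then b.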
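The index convention described above can be illustrated with a plain NumPy sketch. This is not the library's implementation: the function name `reference_Na_Maib_Ni_to_NMb` is hypothetical, and the sketch assumes that N is the shared leading shape of `Na` and `Ni`, that M is the leading shape of `Maib`, and that the result has shape N + M + (b,), as the examples below suggest. It flattens each index group to a single axis, runs one einsum, and restores the groups.

```python
import math
import numpy as np


def reference_Na_Maib_Ni_to_NMb(Na, Maib, Ni):
    """Hypothetical reference for the contraction a,aib,i->b, vectorized over N and M."""
    Na, Maib, Ni = np.asarray(Na), np.asarray(Maib), np.asarray(Ni)

    # Split each shape into its vectorization group and its contracted axes.
    N_shape = Na.shape[:-1]       # leading axes shared by Na and Ni
    M_shape = Maib.shape[:-3]     # leading axes of Maib
    a, i, b = Maib.shape[-3:]

    if Ni.shape[:-1] != N_shape:
        raise ValueError("Na and Ni must share the same leading (N) shape")
    if Na.shape[-1] != a or Ni.shape[-1] != i:
        raise ValueError("trailing axes of Na and Ni must match a and i of Maib")

    # Collapse each index group to one axis (an empty group becomes length 1).
    n = math.prod(N_shape)
    m = math.prod(M_shape)
    out = np.einsum(
        "na,maib,ni->nmb",
        Na.reshape(n, a),
        Maib.reshape(m, a, i, b),
        Ni.reshape(n, i),
    )

    # Restore the original index groups: N, then M, then b.
    return out.reshape(N_shape + M_shape + (b,))


rng = np.random.default_rng(0)
xyz_a = rng.standard_normal((2, 3, 4, 10))
uv_aib = rng.standard_normal((5, 6, 10, 11, 12))
xyz_i = rng.standard_normal((2, 3, 4, 11))
print(reference_Na_Maib_Ni_to_NMb(xyz_a, uv_aib, xyz_i).shape)  # (2, 3, 4, 5, 6, 12)
```

Flattening the groups keeps the einsum string fixed regardless of how many indices N and M contain, which is one simple way to support "groups of indices"; the actual function may take a different approach.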
Examples
Vectorize over both N and M:
>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> xyz_a = np.random.randn(2,3,4, 10)
>>> uv_aib = np.random.randn(5,6, 10,11,12)
>>> xyz_i = np.random.randn(2,3,4, 11)
>>> NMb = Na_Maib_Ni_to_NMb(xyz_a, uv_aib, xyz_i)
>>> NMb_true = np.einsum('xyza,uvaib,xyzi->xyzuvb', xyz_a, uv_aib, xyz_i)
>>> print(NMb.shape == NMb_true.shape)
True
>>> print(np.linalg.norm(NMb - NMb_true))
3.5869432063566724e-13
Vectorize over N only:
>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> xyz_a = np.random.randn(2,3,4, 10)
>>> aib = np.random.randn(10,11,12)
>>> xyz_i = np.random.randn(2,3,4, 11)
>>> Nb = Na_Maib_Ni_to_NMb(xyz_a, aib, xyz_i)
>>> Nb_true = np.einsum('xyza,aib,xyzi->xyzb', xyz_a, aib, xyz_i)
>>> print(Nb.shape == Nb_true.shape)
True
>>> print(np.linalg.norm(Nb - Nb_true))
7.459556385862986e-14
Vectorize over M only:
>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> a = np.random.randn(10)
>>> uv_aib = np.random.randn(5,6, 10,11,12)
>>> i = np.random.randn(11)
>>> Mb = Na_Maib_Ni_to_NMb(a, uv_aib, i)
>>> Mb_true = np.einsum('a,uvaib,i->uvb', a, uv_aib, i)
>>> print(Mb.shape == Mb_true.shape)
True
>>> print(np.linalg.norm(Mb - Mb_true))
1.254699383909023e-14
No vectorization:
>>> import numpy as np
>>> from t3toolbox.backend.contractions import Na_Maib_Ni_to_NMb
>>> a = np.random.randn(10)
>>> aib = np.random.randn(10,11,12)
>>> i = np.random.randn(11)
>>> b = Na_Maib_Ni_to_NMb(a, aib, i)
>>> b_true = np.einsum('a,aib,i->b', a, aib, i)
>>> print(b.shape == b_true.shape)
True
>>> print(np.linalg.norm(b - b_true))
6.108244889215317e-15