t3toolbox.backend.contractions.MNa_Maib_MNb_to_MNi

t3toolbox.backend.contractions.MNa_Maib_MNb_to_MNi(MNa: t3toolbox.backend.common.NDArray, Maib: t3toolbox.backend.common.NDArray, MNb: t3toolbox.backend.common.NDArray, use_jax: bool = False) -> t3toolbox.backend.common.NDArray

Computes the contraction MNa,Maib,MNb->MNi, that is, result[M, N, i] = sum over a and b of MNa[M, N, a] * Maib[M, a, i, b] * MNb[M, N, b].

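M and N may each be a single index, a group of indices, or absent. The M axes lead in all three inputs. The N axes appear only in MNa and MNb, so the same Maib is reused across every N index.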
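As a minimal NumPy sketch of these semantics, the helper below infers the M and N groups from the array ranks and builds the matching einsum string. It assumes the layout used in the examples (Maib has shape M + (a, i, b), MNa has shape M + N + (a,), MNb has shape M + N + (b,)). The name mna_maib_mnb_to_mni_reference is hypothetical and not part of t3toolbox, and the sketch ignores the use_jax option.

```python
import string

import numpy as np


def mna_maib_mnb_to_mni_reference(MNa, Maib, MNb):
    """Hypothetical reference for MNa,Maib,MNb->MNi (not part of t3toolbox).

    Assumes Maib has shape M + (a, i, b), MNa has shape M + N + (a,),
    and MNb has shape M + N + (b,), with the M axes leading.
    """
    num_m = Maib.ndim - 3            # number of shared leading M axes
    num_n = MNa.ndim - num_m - 1     # remaining batch axes of MNa belong to N
    if num_m < 0 or num_n < 0 or MNb.ndim != MNa.ndim:
        raise ValueError("inconsistent ranks for MNa, Maib, MNb")
    # Uppercase letters label batch axes so they never clash with a, i, b.
    m = string.ascii_uppercase[:num_m]
    n = string.ascii_uppercase[num_m:num_m + num_n]
    subscripts = f"{m}{n}a,{m}aib,{m}{n}b->{m}{n}i"
    return np.einsum(subscripts, MNa, Maib, MNb)


rng = np.random.default_rng(0)
MNa = rng.standard_normal((2, 3, 4, 5, 6, 10))
Maib = rng.standard_normal((2, 3, 10, 11, 12))
MNb = rng.standard_normal((2, 3, 4, 5, 6, 12))
result = mna_maib_mnb_to_mni_reference(MNa, Maib, MNb)
expected = np.einsum('uvxyza,uvaib,uvxyzb->uvxyzi', MNa, Maib, MNb)
print(result.shape)                    # (2, 3, 4, 5, 6, 11)
print(np.allclose(result, expected))   # True
```

With no M or N axes the generated string reduces to 'a,aib,b->i', matching the unvectorized case.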

Examples

Vectorize over both N and M:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import MNa_Maib_MNb_to_MNi
>>> MNa = np.random.randn(2,3, 4,5,6, 10)
>>> Maib = np.random.randn(2,3, 10,11,12)
>>> MNb = np.random.randn(2,3, 4,5,6, 12)
>>> result = MNa_Maib_MNb_to_MNi(MNa, Maib, MNb)
>>> result2 = np.einsum('uvxyza,uvaib,uvxyzb->uvxyzi', MNa, Maib, MNb)
>>> print(result.shape == result2.shape)
True
>>> print(np.allclose(result, result2))
True
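A three-operand contraction can be evaluated as two pairwise contractions. The sketch below uses one possible order (contract over a first, then over b); it is an illustration with NumPy only and makes no claim about the order t3toolbox actually uses. Note that the intermediate has shape M + N + (i, b), which can be large when N is large.

```python
import numpy as np

rng = np.random.default_rng(0)
MNa = rng.standard_normal((2, 3, 4, 5, 6, 10))
Maib = rng.standard_normal((2, 3, 10, 11, 12))
MNb = rng.standard_normal((2, 3, 4, 5, 6, 12))

# Step 1: contract over a. The intermediate has shape M + N + (i, b).
tmp = np.einsum('uvxyza,uvaib->uvxyzib', MNa, Maib)
# Step 2: contract over b with MNb.
two_step = np.einsum('uvxyzib,uvxyzb->uvxyzi', tmp, MNb)

one_shot = np.einsum('uvxyza,uvaib,uvxyzb->uvxyzi', MNa, Maib, MNb)
print(tmp.shape)                        # (2, 3, 4, 5, 6, 11, 12)
print(np.allclose(two_step, one_shot))  # True
```

np.einsum also accepts optimize=True, which lets NumPy choose a pairwise order on its own.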

Vectorize over N only:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import MNa_Maib_MNb_to_MNi
>>> MNa = np.random.randn(4,5,6, 10)
>>> Maib = np.random.randn(10,11,12)
>>> MNb = np.random.randn(4,5,6, 12)
>>> result = MNa_Maib_MNb_to_MNi(MNa, Maib, MNb)
>>> result2 = np.einsum('xyza,aib,xyzb->xyzi', MNa, Maib, MNb)
>>> print(result.shape == result2.shape)
True
>>> print(np.allclose(result, result2))
True
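Vectorizing over N means the single Maib is shared by every N index. The NumPy sketch below spells this out with an explicit loop over the N axes and checks it against the vectorized einsum; it is for illustration only and is not how the library function is written.

```python
import numpy as np

rng = np.random.default_rng(0)
MNa = rng.standard_normal((4, 5, 6, 10))
Maib = rng.standard_normal((10, 11, 12))
MNb = rng.standard_normal((4, 5, 6, 12))

vectorized = np.einsum('xyza,aib,xyzb->xyzi', MNa, Maib, MNb)

# Explicit loop over every N index; the same Maib is reused each time.
looped = np.empty(MNa.shape[:-1] + (Maib.shape[1],))
for idx in np.ndindex(*MNa.shape[:-1]):
    looped[idx] = np.einsum('a,aib,b->i', MNa[idx], Maib, MNb[idx])

print(np.allclose(vectorized, looped))  # True
```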

Vectorize over M only:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import MNa_Maib_MNb_to_MNi
>>> MNa = np.random.randn(2,3, 10)
>>> Maib = np.random.randn(2,3, 10,11,12)
>>> MNb = np.random.randn(2,3, 12)
>>> result = MNa_Maib_MNb_to_MNi(MNa, Maib, MNb)
>>> result2 = np.einsum('uva,uvaib,uvb->uvi', MNa, Maib, MNb)
>>> print(result.shape == result2.shape)
True
>>> print(np.allclose(result, result2))
True
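When vectorizing over M, each M entry carries its own slice of Maib. One way to see this is as a batched matrix-vector product: contract over a first, then multiply each (i, b) matrix by the matching b vector. This NumPy sketch only illustrates the equivalence and is not t3toolbox's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
MNa = rng.standard_normal((2, 3, 10))
Maib = rng.standard_normal((2, 3, 10, 11, 12))
MNb = rng.standard_normal((2, 3, 12))

reference = np.einsum('uva,uvaib,uvb->uvi', MNa, Maib, MNb)

# Contract over a; each M entry uses its own Maib slice -> shape (2, 3, 11, 12).
tmp = np.einsum('uva,uvaib->uvib', MNa, Maib)
# Batched matrix-vector product over b: (2, 3, 11, 12) @ (2, 3, 12, 1) -> (2, 3, 11, 1).
batched = (tmp @ MNb[..., None])[..., 0]

print(np.allclose(reference, batched))  # True
```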

No vectorization:

>>> import numpy as np
>>> from t3toolbox.backend.contractions import MNa_Maib_MNb_to_MNi
>>> MNa = np.random.randn(10)
>>> Maib = np.random.randn(10,11,12)
>>> MNb = np.random.randn(12)
>>> result = MNa_Maib_MNb_to_MNi(MNa, Maib, MNb)
>>> result2 = np.einsum('a,aib,b->i', MNa, Maib, MNb)
>>> print(result.shape == result2.shape)
True
>>> print(np.allclose(result, result2))
True
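Without vectorization, each output entry i is a bilinear form: the vector MNa, the matrix Maib[:, i, :], and the vector MNb. The short NumPy sketch below computes it that way and compares it with einsum; the variable names a, A, b are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(10)
A = rng.standard_normal((10, 11, 12))
b = rng.standard_normal(12)

reference = np.einsum('a,aib,b->i', a, A, b)

# Each output entry is the bilinear form a^T A[:, i, :] b.
bilinear = np.array([a @ A[:, i, :] @ b for i in range(A.shape[1])])

print(bilinear.shape)                     # (11,)
print(np.allclose(reference, bilinear))   # True
```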