Einstein Summation Convention (einsum)

The Einstein summation convention is a concise notation for tensor operations in mathematics and physics, introduced by Albert Einstein in 1916. In deep learning frameworks, the einsum function leverages this convention to provide a unified, elegant, and powerful way to express various tensor operations.

What is the Einstein Summation Convention

Mathematical Background

In traditional tensor algebra, complex tensor operations typically involve numerous summation symbols and index notations. The Einstein summation convention simplifies expressions by omitting summation symbols. The core rule is: when the same index appears twice in a term, it indicates summation over that index.

For example, traditional matrix multiplication can be written as:

\[C_{ik} = \sum_{j} A_{ij} B_{jk}\]

Using the Einstein summation convention, this simplifies to:

\[C_{ik} = A_{ij} B_{jk}\]

Here, the index j appears twice on the right-hand side (in both A and B), indicating summation over j.

Core Concepts of einsum

The einsum function in Riemann implements this convention, allowing users to describe complex tensor operations through concise string equations. Its advantages include:

  1. Universality: One function can replace many different tensor operations

  2. Readability: The equation string intuitively expresses the mathematical meaning of the operation

  3. Flexibility: Supports tensors of arbitrary dimensions and complex index manipulations

  4. Efficiency: Internal optimizations ensure high-performance computation

einsum Equation String Syntax

Basic Syntax Structure

The einsum equation string follows this format:

"input1_indices,input2_indices,...->output_indices"

Where:

  • Input indices: Letters representing the dimensions of corresponding input tensors, e.g., ij for a 2D tensor (matrix)

  • Output indices: Specify the dimensions of the output tensor; if omitted, indices are arranged in alphabetical order

  • Ellipsis ``…``: Represents any number of batch dimensions

  • Repeated indices: Indicate summation (contraction) over that dimension

Detailed Index Rules

  1. Unique indices: If an index appears in only one input and also in the output, that dimension is preserved

  2. Repeated indices: If an index appears in multiple inputs, it indicates element-wise multiplication followed by summation over those dimensions

  3. Missing indices: If an index from the input does not appear in the output, it indicates summation reduction over that dimension

Example Analysis

# Matrix multiplication: ij,jk->ik
# i: row index of A, j: column index of A/row index of B (summed), k: column index of B
# Result C has dimensions (i, k)
C = rm.einsum('ij,jk->ik', A, B)

# Batch matrix multiplication: ...ij,...jk->...ik
# ... represents any batch dimensions
C = rm.einsum('...ij,...jk->...ik', A, B)

# Trace operation: ii->
# i appears twice, indicating summation of diagonal elements
trace = rm.einsum('ii->', A)

# Diagonal extraction: ii->i
# Preserves diagonal elements, result is a vector
diag = rm.einsum('ii->i', A)

Classification of einsum Computation Scenarios

The following tables detail the various computation scenarios that einsum can replace:

Basic Matrix Operations

Basic Matrix Operations

Operation Type

Mathematical Description

einsum Equation

Equivalent Function

Matrix Multiplication

\(C_{ik} = \sum_j A_{ij} B_{jk}\)

ij,jk->ik

rm.matmul(A, B)

Batch Matrix Multiplication

\(C_{bik} = \sum_j A_{bij} B_{bjk}\)

bij,bjk->bik

rm.matmul(A, B)

General Batch Matrix Multiplication

Supports arbitrary batch dimensions

...ij,...jk->...ik

rm.matmul(A, B)

Vector Dot Product

\(c = \sum_i a_i b_i\)

i,i->

rm.dot(a, b)

Vector Outer Product

\(C_{ij} = a_i b_j\)

i,j->ij

rm.outer(a, b)

Matrix Property Extraction

Matrix Property Extraction

Operation Type

Mathematical Description

einsum Equation

Equivalent Function

Matrix Trace

\(\text{tr}(A) = \sum_i A_{ii}\)

ii->

rm.trace(A)

Diagonal Extraction

\(\text{diag}(A)_i = A_{ii}\)

ii->i

rm.diag(A)

Batch Matrix Trace

\(\text{tr}(A_b) = \sum_i A_{bii}\)

bii->b

rm.trace(A)

Batch Diagonal Extraction

\(\text{diag}(A_b)_i = A_{bii}\)

bii->bi

rm.diag(A)

Transpose and Dimension Permutation

Transpose and Dimension Permutation

Operation Type

Mathematical Description

einsum Equation

Equivalent Function

Matrix Transpose

\(C_{ji} = A_{ij}\)

ij->ji

A.T or rm.transpose(A)

High-Dimensional Transpose

\(C_{jki} = A_{ijk}\)

ijk->jki

rm.permute(A, (1, 2, 0))

Batch Transpose

\(C_{bji} = A_{bij}\)

...ij->...ji

A.mT

Tensor Contraction and Summation

Tensor Contraction and Summation

Operation Type

Mathematical Description

einsum Equation

Equivalent Function

Sum All Elements

\(s = \sum_{i,j} A_{ij}\)

ij->

rm.sum(A)

Sum Along Rows

\(s_i = \sum_j A_{ij}\)

ij->i

rm.sum(A, dim=1)

Sum Along Columns

\(s_j = \sum_i A_{ij}\)

ij->j

rm.sum(A, dim=0)

Tensor Contraction

\(C_{ijm} = \sum_{k,l} A_{ijkl} B_{jklm}\)

ijkl,jklm->ijm

No direct equivalent

Self-Contraction

\(C_i = \sum_j A_{ij} B_{ij}\)

ij,ij->i

(A * B).sum(dim=1)

Special Matrix Operations

Special Matrix Operations

Operation Type

Mathematical Description

einsum Equation

Equivalent Function

Hadamard Product

\(C_{ij} = A_{ij} B_{ij}\)

ij,ij->ij

A * B

Frobenius Inner Product

\(\langle A, B \rangle_F = \sum_{i,j} A_{ij} B_{ij}\)

ij,ij->

(A * B).sum()

Kronecker Product

\(C_{ikjl} = A_{ij} B_{kl}\)

ij,kl->ikjl

rm.kron(A, B)

Identity Copy

\(C_{ij} = A_{ij}\)

ij->ij

A.clone()

Multi-Operand Operations

Multi-Operand Operations

Operation Type

Mathematical Description

einsum Equation

Description

Three-Operand Chain

\(C_{il} = \sum_{j,k} A_{ij} B_{jk} C_{kl}\)

ij,jk,kl->il

Sequential matrix multiplication

Four-Operand Chain

\(C_{im} = \sum_{j,k,l} A_{ij} B_{jk} C_{kl} D_{lm}\)

ij,jk,kl,lm->im

Long chain matrix multiplication

Multi-Operand Mixed

\(C_i = \sum_{j,k} A_{ij} B_{jk} C_{ik}\)

ij,jk,ik->i

Complex mixed operation

Batch Three-Operand

Supports arbitrary batch dimensions

...ij,...jk,...kl->...il

Batch chain multiplication

Repeated Index Operations

Repeated Index Operations

Operation Type

Mathematical Description

einsum Equation

Description

First Two Indices Repeated

\(C_j = \sum_{i} A_{iij}\)

iij->j

Extract specific diagonal

Multiple Indices Repeated

\(C_{ij} = \sum_{k,l} A_{iijj}\)

iijj->ij

High-dimensional diagonal extraction

Non-Contiguous Index Repeat

\(s = \sum_{i,j} A_{ijji}\)

ijji->

Anti-diagonal summation

1D Vector Operations

1D Vector Operations

Operation Type

Mathematical Description

einsum Equation

Equivalent Function

Vector Dot Product

\(c = \sum_i a_i b_i\)

i,i->

rm.dot(a, b)

Vector Outer Product

\(C_{ij} = a_i b_j\)

i,j->ij

rm.outer(a, b)

Matrix-Vector Multiplication

\(c_i = \sum_j A_{ij} b_j\)

ij,j->i

rm.matmul(A, b)

Vector-Matrix Multiplication

\(c_j = \sum_i a_i A_{ij}\)

i,ij->j

rm.matmul(a, A)

Batch Matrix-Vector Multiplication

\(C_{bi} = \sum_j A_{bij} b_{bj}\)

bij,bj->bi

rm.matmul(A, b.unsqueeze(-1)).squeeze(-1)

einsum Usage Examples

Example 1: Matrix Multiplication

import riemann as rm

# Create matrices
A = rm.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2
B = rm.tensor([[7, 8, 9], [10, 11, 12]])  # 2x3

# Matrix multiplication: (3x2) @ (2x3) = (3x3)
C = rm.einsum('ij,jk->ik', A, B)
print("Matrix multiplication result:")
print(C)
# Output:
# tensor([[ 27,  30,  33],
#         [ 61,  68,  75],
#         [ 95, 106, 117]])

Example 2: Batch Matrix Multiplication

import riemann as rm

# Create batch matrices (2 matrices of 3x4)
A = rm.randn(2, 3, 4)
# Create batch matrices (2 matrices of 4x5)
B = rm.randn(2, 4, 5)

# Batch matrix multiplication
C = rm.einsum('bij,bjk->bik', A, B)
print(f"Batch matrix multiplication result shape: {C.shape}")  # (2, 3, 5)

# Use ellipsis to support more batch dimensions
A = rm.randn(2, 3, 4, 5)  # 2x3 matrices of 4x5
B = rm.randn(2, 3, 5, 6)  # 2x3 matrices of 5x6
C = rm.einsum('...ij,...jk->...ik', A, B)
print(f"General batch multiplication result shape: {C.shape}")  # (2, 3, 4, 6)

Example 3: Trace and Diagonal Operations

import riemann as rm

# Create square matrix
A = rm.tensor([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])

# Compute trace (sum of diagonal elements)
trace = rm.einsum('ii->', A)
print(f"Matrix trace: {trace}")  # 1 + 5 + 9 = 15

# Extract diagonal elements
diag = rm.einsum('ii->i', A)
print(f"Diagonal elements: {diag}")  # [1, 5, 9]

# Batch matrix trace
batch_A = rm.randn(4, 3, 3)  # 4 matrices of 3x3
batch_trace = rm.einsum('bii->b', batch_A)
print(f"Batch trace shape: {batch_trace.shape}")  # (4,)

Example 4: Transpose and Dimension Permutation

import riemann as rm

# Create matrix
A = rm.tensor([[1, 2, 3],
               [4, 5, 6]])

# Matrix transpose
A_T = rm.einsum('ij->ji', A)
print("Transpose result:")
print(A_T)
# Output:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# High-dimensional transpose
B = rm.randn(2, 3, 4)
B_perm = rm.einsum('ijk->jki', B)
print(f"High-dimensional transpose shape: {B_perm.shape}")  # (3, 4, 2)

Example 5: Vector Operations

import riemann as rm

# Create vectors
a = rm.tensor([1, 2, 3])
b = rm.tensor([4, 5, 6])

# Vector dot product
dot_product = rm.einsum('i,i->', a, b)
print(f"Dot product: {dot_product}")  # 1*4 + 2*5 + 3*6 = 32

# Vector outer product
outer_product = rm.einsum('i,j->ij', a, b)
print("Outer product result:")
print(outer_product)
# Output:
# tensor([[ 4,  5,  6],
#         [ 8, 10, 12],
#         [12, 15, 18]])

# Matrix-vector multiplication
A = rm.tensor([[1, 2, 3],
               [4, 5, 6]])
c = rm.einsum('ij,j->i', A, a)
print(f"Matrix-vector multiplication: {c}")  # [14, 32]

Example 6: Tensor Contraction

import riemann as rm

# Create 3D tensors
A = rm.randn(2, 3, 4)
B = rm.randn(3, 4, 5)

# Tensor contraction: sum over dimensions 1 and 2
C = rm.einsum('ijk,jkl->il', A, B)
print(f"Tensor contraction result shape: {C.shape}")  # (2, 5)

# More complex contraction
D = rm.randn(2, 3, 4, 5)
E = rm.randn(3, 4, 5, 6)
F = rm.einsum('ijkl,jklm->im', D, E)
print(f"Complex contraction result shape: {F.shape}")  # (2, 6)

Example 7: Hadamard Product and Frobenius Inner Product

import riemann as rm

# Create matrices
A = rm.tensor([[1, 2],
               [3, 4]])
B = rm.tensor([[5, 6],
               [7, 8]])

# Hadamard product (element-wise multiplication)
hadamard = rm.einsum('ij,ij->ij', A, B)
print("Hadamard product:")
print(hadamard)
# Output:
# tensor([[ 5, 12],
#         [21, 32]])

# Frobenius inner product
frobenius = rm.einsum('ij,ij->', A, B)
print(f"Frobenius inner product: {frobenius}")  # 5 + 12 + 21 + 32 = 70

Example 8: Multi-Operand Operations

import riemann as rm

# Create multiple matrices
A = rm.randn(3, 4)
B = rm.randn(4, 5)
C = rm.randn(5, 6)
D = rm.randn(6, 7)

# Four-operand chain multiplication
result = rm.einsum('ij,jk,kl,lm->im', A, B, C, D)
print(f"Four-operand chain result shape: {result.shape}")  # (3, 7)

# Equivalent to:
# temp1 = rm.matmul(A, B)
# temp2 = rm.matmul(temp1, C)
# result = rm.matmul(temp2, D)

Example 9: einsum with Gradient Tracking

import riemann as rm

# Create tensors requiring gradients
A = rm.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
B = rm.tensor([[5.0, 6.0], [7.0, 8.0]], requires_grad=True)

# Perform einsum operation
C = rm.einsum('ij,jk->ik', A, B)

# Compute loss and backpropagate
loss = C.sum()
loss.backward()

print("Gradient of A:")
print(A.grad)
print("Gradient of B:")
print(B.grad)

Example 10: Implicit Output (Omitting Output Indices)

import riemann as rm

A = rm.tensor([[1, 2], [3, 4]])
B = rm.tensor([[5, 6], [7, 8]])

# Implicit output: omit the part after ->
# Result indices are arranged in alphabetical order
C = rm.einsum('ij,jk', A, B)  # Equivalent to 'ij,jk->ik'
print("Implicit output result:")
print(C)

# Batch implicit output
A = rm.randn(2, 3, 4)
B = rm.randn(2, 4, 5)
C = rm.einsum('...ij,...jk', A, B)  # Equivalent to '...ij,...jk->...ik'
print(f"Batch implicit output shape: {C.shape}")  # (2, 3, 5)

einsum Performance Optimization Tips

  1. Prefer simple equations: For common matrix multiplications, using rm.matmul directly may be more efficient

  2. Avoid unnecessary copies: einsum returns views rather than copies whenever possible

  3. Batch operations over loops: Use ... to represent batch dimensions instead of explicit loops

  4. Chain operation fusion: Multiple matrix multiplications can be combined into one einsum call, reducing intermediate results

  5. Pre-compiled equations: For repeatedly used equations, einsum automatically caches optimizations

Notes

  1. Index letter limitation: Indices use lowercase letters (a-z), supporting up to 26 different indices

  2. Dimension matching: Dimensions of repeated indices must be consistent

  3. Device consistency: All input tensors must be on the same device

  4. Data types: einsum follows standard type promotion rules

  5. Gradient tracking: Supports automatic differentiation and normal gradient computation