Einstein Summation Convention (einsum)
The Einstein summation convention is a concise notation for tensor operations in mathematics and physics, introduced by Albert Einstein in 1916. In deep learning frameworks, the einsum function uses this convention to express many tensor operations through a single interface, including matrix products, transposes, traces, reductions, and general contractions.
What is the Einstein Summation Convention
Mathematical Background
In traditional tensor algebra, complex tensor operations typically involve numerous summation symbols and index notations. The Einstein summation convention simplifies expressions by omitting summation symbols. The core rule is: when the same index appears twice in a term, it indicates summation over that index.
For example, matrix multiplication is traditionally written as:
\[ C_{ik} = \sum_j A_{ij} B_{jk} \]
Using the Einstein summation convention, this simplifies to:
\[ C_{ik} = A_{ij} B_{jk} \]
Here, the index j appears twice on the right-hand side (in both A and B), indicating summation over j.
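The loops below spell out \(C_{ik} = A_{ij} B_{jk}\) using ordinary Python lists, with no Riemann calls. Each free index (i and k) becomes an outer loop over the result. The repeated index j becomes the innermost loop, which accumulates the sum. The matrices are the same ones used in Example 1 further down this page.

```python
# The summation convention written out by hand: C_ik = A_ij B_jk,
# where the repeated index j is summed over implicitly.
A = [[1, 2], [3, 4], [5, 6]]        # shape (3, 2): indices i, j
B = [[7, 8, 9], [10, 11, 12]]       # shape (2, 3): indices j, k

n_i, n_j, n_k = len(A), len(B), len(B[0])

C = [[0] * n_k for _ in range(n_i)]
for i in range(n_i):                # free index i: kept in the result
    for k in range(n_k):            # free index k: kept in the result
        for j in range(n_j):        # repeated index j: summed over
            C[i][k] += A[i][j] * B[j][k]

print(C)  # [[27, 30, 33], [61, 68, 75], [95, 106, 117]]
```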
Core Concepts of einsum
The einsum function in Riemann implements this convention, allowing users to describe complex tensor operations through concise string equations. Its advantages include:
Universality: One function can replace many different tensor operations
Readability: The equation string intuitively expresses the mathematical meaning of the operation
Flexibility: Supports tensors of arbitrary dimensions and complex index manipulations
Efficiency: Internal optimizations ensure high-performance computation
einsum Equation String Syntax
Basic Syntax Structure
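The einsum equation string follows this format:
`input1_indices,input2_indices,...->output_indices`
Here the `...` after the last comma simply means "further inputs"; it is unrelated to the ellipsis syntax described below.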
Where:
Input indices: letters labeling the dimensions of each input tensor, e.g., `ij` for a 2D tensor (matrix)
Output indices: letters specifying the dimensions of the output tensor, in order. If `->` and the output part are omitted, the output uses every index that appears exactly once across the inputs, arranged in alphabetical order
Ellipsis `...`: represents any number of batch dimensions
Repeated indices: an index that appears more than once is contracted (multiplied along and summed over) unless it also appears in the output
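The sketch below shows one way to split an equation string into its parts. `parse_equation` is a hypothetical name and is not part of Riemann. It assumes single-letter indices and no ellipsis. For implicit mode it applies the rule NumPy's `einsum` documents: keep the indices that appear exactly once, sorted alphabetically. Example 10 on this page is consistent with that rule, but the sketch does not claim to reproduce Riemann's parser.

```python
def parse_equation(equation):
    """Split an einsum equation into input index strings and an output string.

    Sketch only: assumes single-letter indices and no ellipsis.
    """
    equation = equation.replace(" ", "")
    if "->" in equation:
        lhs, output = equation.split("->")
        inputs = lhs.split(",")
    else:
        inputs = equation.split(",")
        # Implicit mode: keep the indices that occur exactly once, sorted alphabetically.
        letters = "".join(inputs)
        output = "".join(sorted(c for c in set(letters) if letters.count(c) == 1))
    return inputs, output


print(parse_equation("ij,jk->ik"))  # (['ij', 'jk'], 'ik')
print(parse_equation("ij,jk"))      # (['ij', 'jk'], 'ik')
print(parse_equation("ii->"))       # (['ii'], '')
print(parse_equation("ba"))         # (['ba'], 'ab')
```

The last call shows a consequence of the alphabetical rule: in implicit mode, `ba` describes a transpose.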
Detailed Index Rules
Unique indices: If an index appears in only one input and also in the output, that dimension is preserved
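Repeated indices: If an index appears in multiple inputs, the matching elements are multiplied together. The index is then summed over only if it does not appear in the output. For example, `ij,ij->ij` is the element-wise (Hadamard) product, while `ij,ij->` sums those products into a single number.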
Missing indices: If an index from the input does not appear in the output, it indicates summation reduction over that dimension
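The naive reference einsum below lets you check these rules against concrete numbers. It works on nested Python lists. `naive_einsum`, `shape_of` and `get` are illustrative names, not Riemann functions. The sketch requires an explicit `->`, does not support the ellipsis, and works as follows:

- It loops over every combination of index values.
- It multiplies the selected element of each operand.
- It sums over the indices that are missing from the output.

Its cost grows with the product of all index sizes, so it is only suitable for tiny inputs.

```python
from itertools import product


def shape_of(x):
    """Shape of a nested-list tensor (a plain number has shape ())."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def get(x, idx):
    """Read the element of nested list x at the index tuple idx."""
    for i in idx:
        x = x[i]
    return x


def naive_einsum(equation, *operands):
    """Reference einsum on nested lists; requires '->' and does not support '...'."""
    lhs, out = equation.replace(" ", "").split("->")
    subs = lhs.split(",")
    sizes = {}  # index letter -> dimension size
    for sub, op in zip(subs, operands):
        shape = shape_of(op)
        if len(sub) != len(shape):
            raise ValueError(f"{sub!r} does not match an operand of rank {len(shape)}")
        for letter, n in zip(sub, shape):
            # Repeated letters must label dimensions of the same size.
            if sizes.setdefault(letter, n) != n:
                raise ValueError(f"size mismatch for index {letter!r}")
    summed = [c for c in sizes if c not in out]  # indices missing from the output

    def build(prefix):
        # Fill the output one axis at a time; prefix holds the output index values so far.
        if len(prefix) < len(out):
            return [build(prefix + (v,)) for v in range(sizes[out[len(prefix)]])]
        env = dict(zip(out, prefix))
        total = 0
        for values in product(*(range(sizes[c]) for c in summed)):
            env.update(zip(summed, values))
            term = 1
            for sub, op in zip(subs, operands):
                term *= get(op, tuple(env[c] for c in sub))  # multiply selected elements
            total += term
        return total

    return build(())


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
X = [[1, 2], [3, 4]]
Y = [[5, 6], [7, 8]]

print(naive_einsum("ii->", A))          # 15: i repeated, not in output -> summed
print(naive_einsum("ii->i", A))         # [1, 5, 9]: i repeated but kept -> diagonal
print(naive_einsum("ij->ji", A))        # [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
print(naive_einsum("ij,ij->ij", X, Y))  # [[5, 12], [21, 32]]: shared and kept -> no sum
print(naive_einsum("ij,ij->", X, Y))    # 70: shared and dropped -> summed
print(naive_einsum("ij,jk->ik", X, Y))  # [[19, 22], [43, 50]]
```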
Example Analysis
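```python
import riemann as rm

A = rm.randn(3, 3)  # square, so the trace and diagonal examples apply
B = rm.randn(3, 4)

# Matrix multiplication: ij,jk->ik
# i: row index of A; j: column index of A and row index of B (summed); k: column index of B
# Result C has dimensions (i, k)
C = rm.einsum('ij,jk->ik', A, B)

# Batch matrix multiplication: ...ij,...jk->...ik
# ... represents any number of leading batch dimensions
A_batch = rm.randn(2, 3, 3)
B_batch = rm.randn(2, 3, 4)
C_batch = rm.einsum('...ij,...jk->...ik', A_batch, B_batch)  # shape (2, 3, 4)

# Trace operation: ii->
# i appears twice and is absent from the output, so the diagonal elements are summed
trace = rm.einsum('ii->', A)

# Diagonal extraction: ii->i
# i is kept in the output, so the diagonal is returned as a vector
diag = rm.einsum('ii->i', A)
```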
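You can read the ellipsis as shorthand for extra index letters inserted at the front of each operand. The hypothetical helper `expand_ellipsis` below rewrites an equation that way, given the rank of each operand. It makes three assumptions:

- Every operand's `...` covers the same number of dimensions, so there is no broadcasting.
- An explicit `->` is present.
- Uppercase letters are free to use. They serve as stand-ins for the batch indices only so they cannot collide with the lowercase ones.

It is not a description of how Riemann handles the ellipsis internally.

```python
import string


def expand_ellipsis(equation, ranks):
    """Rewrite '...' as explicit index letters, given each operand's rank.

    Assumptions (sketch only): an explicit '->' is present, every operand's
    '...' spans the same number of dimensions, and uppercase letters are free
    to use as stand-ins for the batch indices.
    """
    lhs, out = equation.split("->")
    subs = lhs.split(",")
    counts = {
        rank - len(sub.replace("...", ""))
        for sub, rank in zip(subs, ranks)
        if "..." in sub
    }
    if len(counts) > 1:
        raise ValueError("operands disagree on how many dimensions '...' covers")
    batch = string.ascii_uppercase[: counts.pop()] if counts else ""
    return ",".join(s.replace("...", batch) for s in subs) + "->" + out.replace("...", batch)


print(expand_ellipsis("...ij,...jk->...ik", [4, 4]))  # ABij,ABjk->ABik
print(expand_ellipsis("...ij,...jk->...ik", [2, 2]))  # ij,jk->ik
```

With rank-2 operands the ellipsis covers no dimensions, and the equation reduces to plain matrix multiplication.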
Classification of einsum Computation Scenarios
The following tables detail the various computation scenarios that einsum can replace:
Basic Matrix Operations
| Operation Type | Mathematical Description | einsum Equation | Equivalent Function |
|---|---|---|---|
| Matrix Multiplication | \(C_{ik} = \sum_j A_{ij} B_{jk}\) | `ij,jk->ik` | `rm.matmul(A, B)` |
| Batch Matrix Multiplication | \(C_{bik} = \sum_j A_{bij} B_{bjk}\) | `bij,bjk->bik` |  |
| General Batch Matrix Multiplication | Supports arbitrary batch dimensions | `...ij,...jk->...ik` |  |
| Vector Dot Product | \(c = \sum_i a_i b_i\) | `i,i->` |  |
| Vector Outer Product | \(C_{ij} = a_i b_j\) | `i,j->ij` |  |
Matrix Property Extraction
| Operation Type | Mathematical Description | einsum Equation | Equivalent Function |
|---|---|---|---|
| Matrix Trace | \(\text{tr}(A) = \sum_i A_{ii}\) | `ii->` |  |
| Diagonal Extraction | \(\text{diag}(A)_i = A_{ii}\) | `ii->i` |  |
| Batch Matrix Trace | \(\text{tr}(A_b) = \sum_i A_{bii}\) | `bii->b` |  |
| Batch Diagonal Extraction | \(\text{diag}(A_b)_i = A_{bii}\) | `bii->bi` |  |
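For a batch of matrices, the rules apply matrix by matrix. The plain-Python sketch below uses nested lists to stand in for a tensor of shape (3, 2, 2). It spells out the batch diagonal `bii->bi` and the batch trace `bii->b`. In both, b appears in the output and is therefore kept. The index i is kept for the diagonal but summed for the trace.

```python
# Batch trace 'bii->b' and batch diagonal 'bii->bi', written as loops.
batch = [
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
    [[9, 10], [11, 12]],
]  # shape (3, 2, 2): indices b, i, i

# 'bii->bi': b and i both appear in the output, so nothing is summed.
batch_diag = [[m[i][i] for i in range(len(m))] for m in batch]

# 'bii->b': i is dropped from the output, so the diagonal is summed per matrix.
batch_trace = [sum(m[i][i] for i in range(len(m))) for m in batch]

print(batch_diag)   # [[1, 4], [5, 8], [9, 12]]
print(batch_trace)  # [5, 13, 21]
```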
Transpose and Dimension Permutation
| Operation Type | Mathematical Description | einsum Equation | Equivalent Function |
|---|---|---|---|
| Matrix Transpose | \(C_{ji} = A_{ij}\) | `ij->ji` |  |
| High-Dimensional Transpose | \(C_{jki} = A_{ijk}\) | `ijk->jki` |  |
| Batch Transpose | \(C_{bji} = A_{bij}\) | `bij->bji` |  |
Tensor Contraction and Summation
| Operation Type | Mathematical Description | einsum Equation | Equivalent Function |
|---|---|---|---|
| Sum All Elements | \(s = \sum_{i,j} A_{ij}\) | `ij->` |  |
| Sum Along Rows | \(s_i = \sum_j A_{ij}\) | `ij->i` |  |
| Sum Along Columns | \(s_j = \sum_i A_{ij}\) | `ij->j` |  |
| Tensor Contraction | \(C_{ijm} = \sum_{k,l} A_{ijkl} B_{jklm}\) | `ijkl,jklm->ijm` | No direct equivalent |
| Row-wise Dot Product | \(C_i = \sum_j A_{ij} B_{ij}\) | `ij,ij->i` |  |
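The reductions in this table differ only in which indices survive in the output. The plain-Python sketch below uses nested lists rather than Riemann tensors. It computes `ij->`, `ij->i`, `ij->j` and `ij,ij->i` on a small 2x3 example.

```python
A = [[1, 2, 3],
     [4, 5, 6]]  # shape (2, 3): indices i, j
B = [[1, 0, 1],
     [0, 1, 0]]  # same shape as A

total = sum(a for row in A for a in row)                                    # 'ij->'
row_sums = [sum(row) for row in A]                                          # 'ij->i'
col_sums = [sum(A[i][j] for i in range(len(A))) for j in range(len(A[0]))]  # 'ij->j'
# 'ij,ij->i': multiply matching elements, then sum over j within each row
row_dots = [sum(a * b for a, b in zip(ra, rb)) for ra, rb in zip(A, B)]

print(total, row_sums, col_sums, row_dots)  # 21 [6, 15] [5, 7, 9] [4, 5]
```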
Special Matrix Operations
| Operation Type | Mathematical Description | einsum Equation | Equivalent Function |
|---|---|---|---|
| Hadamard Product | \(C_{ij} = A_{ij} B_{ij}\) | `ij,ij->ij` |  |
| Frobenius Inner Product | \(\langle A, B \rangle_F = \sum_{i,j} A_{ij} B_{ij}\) | `ij,ij->` |  |
| Kronecker Product | \(C_{ikjl} = A_{ij} B_{kl}\) | `ij,kl->ikjl` |  |
| Identity Copy | \(C_{ij} = A_{ij}\) | `ij->ij` |  |
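The Kronecker product equation yields a 4D tensor indexed (i, k, j, l). To get the familiar 2D Kronecker matrix of shape (m*p, n*q), merge the (i, k) axes into rows and the (j, l) axes into columns, which is a reshape. The plain-Python sketch below performs both steps on nested lists so the index bookkeeping is visible. It does not use or assume any particular Riemann reshape function.

```python
A = [[1, 2],
     [3, 4]]          # shape (m, n) = (2, 2): indices i, j
B = [[0, 5],
     [6, 7]]          # shape (p, q) = (2, 2): indices k, l

m, n, p, q = len(A), len(A[0]), len(B), len(B[0])

# 'ij,kl->ikjl': every product A[i][j] * B[k][l], stored at position [i][k][j][l].
C = [[[[A[i][j] * B[k][l] for l in range(q)] for j in range(n)]
      for k in range(p)] for i in range(m)]

# Merging (i, k) into rows and (j, l) into columns gives the Kronecker product:
# kron[i*p + k][j*q + l] = A[i][j] * B[k][l]
kron = [[C[r // p][r % p][c // q][c % q] for c in range(n * q)] for r in range(m * p)]

for row in kron:
    print(row)
# [0, 5, 0, 10]
# [6, 7, 12, 14]
# [0, 15, 0, 20]
# [18, 21, 24, 28]
```

The index order `ikjl` in the equation matters. With `ijkl`, the same merge would interleave the blocks incorrectly.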
Multi-Operand Operations
| Operation Type | Mathematical Description | einsum Equation | Description |
|---|---|---|---|
| Three-Operand Chain | \(R_{il} = \sum_{j,k} A_{ij} B_{jk} C_{kl}\) | `ij,jk,kl->il` | Sequential matrix multiplication |
| Four-Operand Chain | \(R_{im} = \sum_{j,k,l} A_{ij} B_{jk} C_{kl} D_{lm}\) | `ij,jk,kl,lm->im` | Long chain matrix multiplication |
| Multi-Operand Mixed | \(r_i = \sum_{j,k} A_{ij} B_{jk} C_{ik}\) | `ij,jk,ik->i` | Complex mixed operation |
| Batch Three-Operand | Supports arbitrary batch dimensions | `...ij,...jk,...kl->...il` | Batch chain multiplication |
Repeated Index Operations
| Operation Type | Mathematical Description | einsum Equation | Description |
|---|---|---|---|
| First Two Indices Repeated | \(C_j = \sum_{i} A_{iij}\) | `iij->j` | Trace over the first two axes (partial trace) |
| Multiple Indices Repeated | \(C_{ij} = A_{iijj}\) | `iijj->ij` | High-dimensional diagonal extraction |
| Non-Contiguous Index Repeat | \(s = \sum_{i,j} A_{ijji}\) | `ijji->` | Double trace pairing axis 0 with axis 3 and axis 1 with axis 2 |
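You can check the repeated-index rules by hand on small tensors whose entries encode their own position. The sketch below uses plain nested lists:

- a (2, 2, 3) tensor with T[a][b][c] = 6a + 3b + c;
- a (2, 2, 2, 2) tensor with A[a][b][c][d] = 8a + 4b + 2c + d.

It then evaluates the three equations from the table with explicit loops.

```python
T = [[[6 * a + 3 * b + c for c in range(3)] for b in range(2)] for a in range(2)]
A = [[[[8 * a + 4 * b + 2 * c + d for d in range(2)] for c in range(2)]
      for b in range(2)] for a in range(2)]

# 'iij->j': axes 0 and 1 share i (a diagonal); i is summed because it is not in the output.
first_two = [sum(T[i][i][j] for i in range(2)) for j in range(3)]

# 'iijj->ij': two diagonals are taken and both indices are kept, so nothing is summed.
double_diag = [[A[i][i][j][j] for j in range(2)] for i in range(2)]

# 'ijji->': axis 0 pairs with axis 3 and axis 1 with axis 2; both indices are summed.
paired_sum = sum(A[i][j][j][i] for i in range(2) for j in range(2))

print(first_two)    # [9, 11, 13]
print(double_diag)  # [[0, 3], [12, 15]]
print(paired_sum)   # 30
```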
1D Vector Operations
| Operation Type | Mathematical Description | einsum Equation | Equivalent Function |
|---|---|---|---|
| Vector Dot Product | \(c = \sum_i a_i b_i\) | `i,i->` |  |
| Vector Outer Product | \(C_{ij} = a_i b_j\) | `i,j->ij` |  |
| Matrix-Vector Multiplication | \(c_i = \sum_j A_{ij} b_j\) | `ij,j->i` |  |
| Vector-Matrix Multiplication | \(c_j = \sum_i a_i A_{ij}\) | `i,ij->j` |  |
| Batch Matrix-Vector Multiplication | \(c_{bi} = \sum_j A_{bij} x_{bj}\) | `bij,bj->bi` |  |
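The batch matrix-vector case combines both kinds of shared index. The index b appears in both inputs and in the output, so it acts as a batch index. The index j appears in both inputs but not in the output, so it is summed. The plain-Python sketch below uses nested lists, with `x` playing the role of the batched vector:

```python
A = [[[1, 2], [3, 4]],
     [[0, 1], [1, 0]]]   # shape (2, 2, 2): indices b, i, j
x = [[1, 1],
     [5, 7]]             # shape (2, 2): indices b, j

# 'bij,bj->bi': b is shared and kept (batch), j is shared and summed.
y = [[sum(A[b][i][j] * x[b][j] for j in range(2)) for i in range(2)] for b in range(2)]
print(y)  # [[3, 7], [7, 5]]
```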
einsum Usage Examples
Example 1: Matrix Multiplication
```python
import riemann as rm

# Create matrices
A = rm.tensor([[1, 2], [3, 4], [5, 6]])   # 3x2
B = rm.tensor([[7, 8, 9], [10, 11, 12]])  # 2x3

# Matrix multiplication: (3x2) @ (2x3) = (3x3)
C = rm.einsum('ij,jk->ik', A, B)
print("Matrix multiplication result:")
print(C)
# Output:
# tensor([[ 27,  30,  33],
#         [ 61,  68,  75],
#         [ 95, 106, 117]])
```
Example 2: Batch Matrix Multiplication
```python
import riemann as rm

# Create a batch of 2 matrices, each 3x4
A = rm.randn(2, 3, 4)
# Create a batch of 2 matrices, each 4x5
B = rm.randn(2, 4, 5)

# Batch matrix multiplication
C = rm.einsum('bij,bjk->bik', A, B)
print(f"Batch matrix multiplication result shape: {C.shape}")  # (2, 3, 5)

# Use the ellipsis to support more batch dimensions
A = rm.randn(2, 3, 4, 5)  # a 2x3 grid of 4x5 matrices
B = rm.randn(2, 3, 5, 6)  # a 2x3 grid of 5x6 matrices
C = rm.einsum('...ij,...jk->...ik', A, B)
print(f"General batch multiplication result shape: {C.shape}")  # (2, 3, 4, 6)
```
Example 3: Trace and Diagonal Operations
```python
import riemann as rm

# Create a square matrix
A = rm.tensor([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])

# Compute the trace (sum of diagonal elements)
trace = rm.einsum('ii->', A)
print(f"Matrix trace: {trace}")  # 1 + 5 + 9 = 15

# Extract the diagonal elements
diag = rm.einsum('ii->i', A)
print(f"Diagonal elements: {diag}")  # [1, 5, 9]

# Batch matrix trace
batch_A = rm.randn(4, 3, 3)  # 4 matrices of 3x3
batch_trace = rm.einsum('bii->b', batch_A)
print(f"Batch trace shape: {batch_trace.shape}")  # (4,)
```
Example 4: Transpose and Dimension Permutation
```python
import riemann as rm

# Create a matrix
A = rm.tensor([[1, 2, 3],
               [4, 5, 6]])

# Matrix transpose
A_T = rm.einsum('ij->ji', A)
print("Transpose result:")
print(A_T)
# Output:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# High-dimensional transpose
B = rm.randn(2, 3, 4)
B_perm = rm.einsum('ijk->jki', B)
print(f"High-dimensional transpose shape: {B_perm.shape}")  # (3, 4, 2)
```
Example 5: Vector Operations
```python
import riemann as rm

# Create vectors
a = rm.tensor([1, 2, 3])
b = rm.tensor([4, 5, 6])

# Vector dot product
dot_product = rm.einsum('i,i->', a, b)
print(f"Dot product: {dot_product}")  # 1*4 + 2*5 + 3*6 = 32

# Vector outer product
outer_product = rm.einsum('i,j->ij', a, b)
print("Outer product result:")
print(outer_product)
# Output:
# tensor([[ 4,  5,  6],
#         [ 8, 10, 12],
#         [12, 15, 18]])

# Matrix-vector multiplication
A = rm.tensor([[1, 2, 3],
               [4, 5, 6]])
c = rm.einsum('ij,j->i', A, a)
print(f"Matrix-vector multiplication: {c}")  # [14, 32]
```
Example 6: Tensor Contraction
```python
import riemann as rm

# Create 3D tensors
A = rm.randn(2, 3, 4)
B = rm.randn(3, 4, 5)

# Tensor contraction: j and k appear in both inputs but not in the output,
# so they are summed (axes 1 and 2 of A, axes 0 and 1 of B)
C = rm.einsum('ijk,jkl->il', A, B)
print(f"Tensor contraction result shape: {C.shape}")  # (2, 5)

# More complex contraction over three shared indices (j, k, l)
D = rm.randn(2, 3, 4, 5)
E = rm.randn(3, 4, 5, 6)
F = rm.einsum('ijkl,jklm->im', D, E)
print(f"Complex contraction result shape: {F.shape}")  # (2, 6)
```
Example 7: Hadamard Product and Frobenius Inner Product
```python
import riemann as rm

# Create matrices
A = rm.tensor([[1, 2],
               [3, 4]])
B = rm.tensor([[5, 6],
               [7, 8]])

# Hadamard product (element-wise multiplication)
hadamard = rm.einsum('ij,ij->ij', A, B)
print("Hadamard product:")
print(hadamard)
# Output:
# tensor([[ 5, 12],
#         [21, 32]])

# Frobenius inner product
frobenius = rm.einsum('ij,ij->', A, B)
print(f"Frobenius inner product: {frobenius}")  # 5 + 12 + 21 + 32 = 70
```
Example 8: Multi-Operand Operations
```python
import riemann as rm

# Create multiple matrices
A = rm.randn(3, 4)
B = rm.randn(4, 5)
C = rm.randn(5, 6)
D = rm.randn(6, 7)

# Four-operand chain multiplication
result = rm.einsum('ij,jk,kl,lm->im', A, B, C, D)
print(f"Four-operand chain result shape: {result.shape}")  # (3, 7)

# Equivalent to:
# temp1 = rm.matmul(A, B)
# temp2 = rm.matmul(temp1, C)
# result = rm.matmul(temp2, D)
```
Example 9: einsum with Gradient Tracking
```python
import riemann as rm

# Create tensors requiring gradients
A = rm.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
B = rm.tensor([[5.0, 6.0], [7.0, 8.0]], requires_grad=True)

# Perform the einsum operation
C = rm.einsum('ij,jk->ik', A, B)

# Compute the loss and backpropagate
loss = C.sum()
loss.backward()

print("Gradient of A:")
print(A.grad)
print("Gradient of B:")
print(B.grad)
```
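The example does not list the printed gradients, but for this loss you can derive them by hand. Since \(L = \sum_{i,k} \sum_j A_{ij} B_{jk}\), we have \(\partial L / \partial A_{ij} = \sum_k B_{jk}\) and \(\partial L / \partial B_{jk} = \sum_i A_{ij}\). The plain-Python sketch below evaluates these formulas for the same A and B and confirms the A gradient with a finite-difference check. If Riemann computes standard gradients, `A.grad` and `B.grad` should hold these values, though this sketch has not been run against Riemann itself.

```python
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]


def loss(A, B):
    # loss = sum over i, k of C_ik, with C_ik = sum_j A_ij B_jk ('ij,jk->ik' then sum)
    return sum(A[i][j] * B[j][k] for i in range(2) for j in range(2) for k in range(2))


# Analytic gradients: dL/dA_ij = sum_k B_jk and dL/dB_jk = sum_i A_ij
grad_A = [[sum(B[j]) for j in range(2)] for i in range(2)]
grad_B = [[sum(A[i][j] for i in range(2)) for k in range(2)] for j in range(2)]

# Finite-difference check of grad_A (the loss is linear in A, so only rounding error remains).
eps = 1e-6
fd_A = []
for i in range(2):
    row = []
    for j in range(2):
        A_shift = [r[:] for r in A]
        A_shift[i][j] += eps
        row.append((loss(A_shift, B) - loss(A, B)) / eps)
    fd_A.append(row)

print(grad_A)  # [[11.0, 15.0], [11.0, 15.0]]
print(grad_B)  # [[4.0, 4.0], [6.0, 6.0]]
```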
Example 10: Implicit Output (Omitting Output Indices)
```python
import riemann as rm

A = rm.tensor([[1, 2], [3, 4]])
B = rm.tensor([[5, 6], [7, 8]])

# Implicit output: omit the part after ->
# The output keeps the indices that appear only once, in alphabetical order
C = rm.einsum('ij,jk', A, B)  # Equivalent to 'ij,jk->ik'
print("Implicit output result:")
print(C)
# Output:
# tensor([[19, 22],
#         [43, 50]])

# Batch implicit output
A = rm.randn(2, 3, 4)
B = rm.randn(2, 4, 5)
C = rm.einsum('...ij,...jk', A, B)  # Equivalent to '...ij,...jk->...ik'
print(f"Batch implicit output shape: {C.shape}")  # (2, 3, 5)
```
einsum Performance Optimization Tips
Prefer simple equations: For common matrix multiplications, using `rm.matmul` directly may be more efficient
Avoid unnecessary copies: einsum returns views rather than copies whenever possible
Batch operations over loops: Use `...` to represent batch dimensions instead of explicit loops
Chain operation fusion: Multiple matrix multiplications can be combined into one einsum call, reducing intermediate results
Pre-compiled equations: For repeatedly used equations, einsum automatically caches optimizations
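Whether a single multi-operand einsum call saves work depends on the order in which the operands are contracted pairwise. This page does not say how Riemann picks that order. As a rough guide, the hypothetical helper `chain_cost` below counts scalar multiplications for the two simplest orders of a matrix chain, where matrix k has shape (dims[k], dims[k + 1]). The shapes in the first pair of calls match Example 8.

```python
def chain_cost(dims, order):
    """Multiply-add count for a matrix chain evaluated pairwise.

    dims: matrix k has shape (dims[k], dims[k + 1]).
    order: 'left' for ((A B) C) D ..., anything else for A (B (C D)).
    """
    n = len(dims) - 1  # number of matrices
    cost = 0
    if order == "left":
        rows = dims[0]
        for k in range(1, n):
            # (rows x dims[k]) @ (dims[k] x dims[k + 1])
            cost += rows * dims[k] * dims[k + 1]
    else:
        cols = dims[-1]
        for k in range(n - 1, 0, -1):
            # (dims[k - 1] x dims[k]) @ (dims[k] x cols)
            cost += dims[k - 1] * dims[k] * cols
    return cost


print(chain_cost([3, 4, 5, 6, 7], "left"))    # 276
print(chain_cost([3, 4, 5, 6, 7], "right"))   # 434
print(chain_cost([10, 100, 5, 50], "left"))   # 7500
print(chain_cost([10, 100, 5, 50], "right"))  # 75000
```

For the second shape, the two orders differ by a factor of 10, so the pairing matters as much as the fusion itself.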
Notes
Index letter limitation: Indices use lowercase letters (a-z), supporting up to 26 different indices
Dimension matching: Dimensions of repeated indices must be consistent
Device consistency: All input tensors must be on the same device
Data types: einsum follows standard type promotion rules
Gradient tracking: einsum supports automatic differentiation, so gradients propagate to every input created with `requires_grad=True` (see Example 9)
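The dimension-matching rule can be expressed as a shape check performed before any arithmetic. The hypothetical function `einsum_output_shape` below is a sketch for explicit equations without an ellipsis. It records the size of each index letter, rejects inconsistent repeats and returns the output shape. It is not Riemann's own validation code, and its error messages are illustrative.

```python
def einsum_output_shape(equation, *shapes):
    """Infer the output shape of an explicit einsum equation (no '...').

    Raises ValueError when an index group does not match its operand's rank,
    or when a repeated index labels dimensions of different sizes.
    """
    lhs, out = equation.replace(" ", "").split("->")
    subs = lhs.split(",")
    if len(subs) != len(shapes):
        raise ValueError("number of index groups and operands differ")
    sizes = {}  # index letter -> dimension size
    for sub, shape in zip(subs, shapes):
        if len(sub) != len(shape):
            raise ValueError(f"{sub!r} does not match a rank-{len(shape)} operand")
        for letter, n in zip(sub, shape):
            if sizes.setdefault(letter, n) != n:
                raise ValueError(f"index {letter!r} has sizes {sizes[letter]} and {n}")
    return tuple(sizes[c] for c in out)


print(einsum_output_shape("ijk,jkl->il", (2, 3, 4), (3, 4, 5)))  # (2, 5)
print(einsum_output_shape("ijk->jki", (2, 3, 4)))                # (3, 4, 2)
try:
    einsum_output_shape("ij,jk->ik", (3, 2), (3, 3))
except ValueError as err:
    print(err)  # index 'j' has sizes 2 and 3
```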