How to create 3D indexable embeddings in PyTorch?
Nathan Ranchin

PyTorch's nn.Embedding module is convenient, but it only lets you create an embedding table of shape (n, d), where n is the number of embeddings and d is the dimension of each embedding.
Sometimes, however, you need embeddings of shape (b, n, d), where b is the batch size, n is the number of embeddings per batch element, and d is the dimension of each embedding: in other words, a separate embedding table for each element of the batch.
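For reference, here is a minimal sketch of the standard nn.Embedding behaviour (the sizes are arbitrary examples, not from the original post): the weight table is 2D, so every batch element looks up the same table.
import torch
import torch.nn as nn
# A standard embedding: one shared weight table of shape (n, d)
n, d = 16, 8
embedding = nn.Embedding(n, d)
print(embedding.weight.shape)   # torch.Size([16, 8])
# A batch of indices with shape (b, n) all indexes into the same shared table
indices = torch.randint(0, n, (4, n))
print(embedding(indices).shape)  # torch.Size([4, 16, 8])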
Naive Solution
To achieve this, we could use a simple for loop:
import torch
import torch.nn.functional as F

# Example sizes: batch size, number of embeddings per batch element, embedding dimension
b, n, d = 4, 16, 8

# Create per-batch embedding weights with shape (b, n, d)
embedding_weight = torch.randn(b, n, d)

# Create random input indices with shape (b, n)
input = torch.randint(0, n, (b, n))

# Initialize result tensor
result = torch.zeros(b, n, d)

# Iterate through each batch element
for i in range(b):
    # Apply the embedding lookup for each batch element separately,
    # using that element's own weight table
    result[i] = F.embedding(input[i], embedding_weight[i])
This works, but it is slow: the Python loop processes batch elements one at a time, which makes it a poor fit for training models.
Optimized Solution
To speed this up, we can use index offsets so that a single embedding lookup handles all batch elements at once:
import torch
import torch.nn.functional as F

# Example sizes: batch size, number of embeddings per batch element, embedding dimension
b, n, d = 4, 16, 8

# Create per-batch embedding weights with shape (b, n, d)
embedding_weight = torch.randn(b, n, d)

# Create random input indices with shape (b, n)
input = torch.randint(0, n, (b, n))

# Create batch offsets so that indices become unique across batch elements:
# batch element i looks up indices in the range [i * n, (i + 1) * n)
offsets = torch.arange(
    input.size(0), device=input.device, dtype=torch.long
).unsqueeze(-1).expand_as(input) * n

# Apply a single embedding lookup with the offset indices and the weights
# flattened to shape (b * n, d); each batch element only reaches its own slice
result = F.embedding(input + offsets, embedding_weight.view(b * n, d))
This version is much faster because it performs the embedding lookup in a single operation, which makes it suitable for training with batched data. The key insight is using offsets to make the indices unique across batch elements, so that one embedding call over the flattened (b * n, d) weight tensor handles the whole batch simultaneously.
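As a quick sanity check, the two approaches can be compared directly; the sizes below are illustrative and not from the original post.
import torch
import torch.nn.functional as F
# Illustrative sizes (assumed for this check)
b, n, d = 4, 16, 8
embedding_weight = torch.randn(b, n, d)
input = torch.randint(0, n, (b, n))
# Loop-based lookup, one batch element at a time
loop_result = torch.stack(
    [F.embedding(input[i], embedding_weight[i]) for i in range(b)]
)
# Single lookup with offset indices over the flattened weights
offsets = torch.arange(b, device=input.device).unsqueeze(-1) * n
offset_result = F.embedding(input + offsets, embedding_weight.view(b * n, d))
print(torch.allclose(loop_result, offset_result))  # True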