import torch
import torch.nn.functional as F


def gumbel_top_k_sampling(logits, k, temperature=1.0, hard=False, eps=1e-10):
    """
    Differentiable Gumbel top-k sampling using the Gumbel-Softmax trick.

    Args:
        logits (torch.Tensor): Input logits of shape (..., vocab_size)
        k (int): Number of top elements to sample
        temperature (float): Temperature for Gumbel-Softmax (lower = more discrete)
        hard (bool): If True, returns a k-hot mask via the straight-through estimator;
            if False, returns a soft distribution over the selected top-k elements
        eps (float): Small constant for numerical stability

    Returns:
        torch.Tensor: Tensor of shape (..., vocab_size). With hard=True it is a
            k-hot mask (each row sums to k); with hard=False it is a renormalized
            distribution over the k selected positions (each row sums to 1).
    """
    # Sample Gumbel(0, 1) noise via the inverse-CDF trick
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    # Add Gumbel noise to the logits and apply the temperature
    gumbel_logits = (logits + gumbel_noise) / temperature
    # Get the indices of the k largest perturbed logits
    _, top_k_indices = torch.topk(gumbel_logits, k, dim=-1)

    if hard:
        # Create a k-hot mask for the top-k elements
        top_k_mask = torch.zeros_like(logits)
        top_k_mask.scatter_(-1, top_k_indices, 1.0)
        # Straight-through estimator: hard selection in the forward pass,
        # soft (softmax) gradients in the backward pass
        gumbel_softmax = F.softmax(gumbel_logits, dim=-1)
        return top_k_mask - gumbel_softmax.detach() + gumbel_softmax
    else:
        # Soft sampling: mask out non-top-k elements and renormalize
        mask = torch.full_like(logits, float('-inf'))
        mask.scatter_(-1, top_k_indices, 0.0)
        masked_gumbel_logits = gumbel_logits + mask
        return F.softmax(masked_gumbel_logits, dim=-1)
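

# Illustrative sketch (an assumption, not part of the original code): a common use of
# the hard k-hot output is differentiable selection of rows from an embedding table.
# The sizes and the `embedding` module below are placeholders.
def _example_embedding_selection():
    embedding = torch.nn.Embedding(10, 32)
    logits = torch.randn(4, 10)
    k_hot = gumbel_top_k_sampling(logits, k=3, temperature=0.5, hard=True)  # (4, 10)
    # The matrix product sums the k selected embedding rows per example; gradients
    # still flow through the softmax term of the straight-through estimator.
    selected = k_hot @ embedding.weight  # (4, 32)
    return selected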


def gumbel_top_k_sampling_v2(logits, k, temperature=1.0, eps=1e-10):
    """
    Alternative implementation using a continuous relaxation of the top-k operation.

    This version maintains better gradients by avoiding hard masking: a sigmoid
    around the k-th largest perturbed logit produces a soft membership mask, so
    elements near the selection boundary still receive gradient signal.

    Args:
        logits (torch.Tensor): Input logits of shape (..., vocab_size)
        k (int): Number of top elements to sample
        temperature (float): Temperature parameter
        eps (float): Small constant for numerical stability

    Returns:
        torch.Tensor: Soft top-k samples of shape (..., vocab_size)
    """
    # Sample Gumbel(0, 1) noise and perturb the logits
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    gumbel_logits = logits + gumbel_noise

    # Continuous relaxation of top-k: sort the perturbed logits and use the
    # k-th largest value as a soft threshold
    sorted_gumbel, _ = torch.sort(gumbel_logits, dim=-1, descending=True)
    threshold = sorted_gumbel[..., k - 1:k]  # k-th largest value

    # Soft membership mask: ~1 well above the threshold, ~0 well below,
    # exactly 0.5 at the k-th largest element
    soft_mask = torch.sigmoid((gumbel_logits - threshold) / temperature)

    # Gate the original logits with the soft mask and renormalize. Note this is
    # multiplicative gating: masked-out logits are pulled toward zero rather than
    # toward -inf, so suppression of non-selected elements is only partial.
    masked_logits = logits * soft_mask
    return F.softmax(masked_logits / temperature, dim=-1)
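

# Illustrative sketch (an assumption, not required by the code above): Gumbel-Softmax
# relaxations are often trained with a temperature that is annealed from a high value
# toward a small one so the samples gradually sharpen. The numbers here are arbitrary.
def _example_temperature_schedule(step, tau_start=1.0, tau_min=0.1, decay_rate=0.9999):
    # Exponential decay toward tau_min; purely illustrative constants.
    return max(tau_min, tau_start * decay_rate ** step)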


# Example usage and testing
if __name__ == "__main__":
    torch.manual_seed(42)

    # Example logits
    batch_size, seq_len, vocab_size = 2, 3, 10
    logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
    k = 3

    print("Original logits shape:", logits.shape)
    print("Sampling top-{} elements".format(k))

    # Test hard sampling
    hard_samples = gumbel_top_k_sampling(logits, k, temperature=0.1, hard=True)
    print("\nHard samples (should have exactly {} non-zero elements per row):".format(k))
    print("Non-zero counts per sample:", (hard_samples > 0.01).sum(dim=-1))

    # Test soft sampling
    soft_samples = gumbel_top_k_sampling(logits, k, temperature=1.0, hard=False)
    print("\nSoft samples (continuous relaxation):")
    print("Sample sums (should be close to 1.0):", soft_samples.sum(dim=-1))

    # Test alternative implementation
    soft_samples_v2 = gumbel_top_k_sampling_v2(logits, k, temperature=1.0)
    print("\nAlternative soft samples:")
    print("Sample sums:", soft_samples_v2.sum(dim=-1))

    # Test gradient flow. A plain hard_samples.sum() would yield a zero gradient,
    # because each softmax row sums to 1 regardless of the logits, so weight the
    # vocabulary positions differently to get a meaningful signal through the
    # straight-through estimator.
    weights = torch.arange(vocab_size, dtype=hard_samples.dtype)
    loss = (hard_samples * weights).sum()
    loss.backward()
    print("\nGradient flow test - logits.grad is not None:", logits.grad is not None)
    print("Gradient norm:", logits.grad.norm().item())


class GumbelTopKSampler(torch.nn.Module):
    """
    PyTorch module wrapper for Gumbel top-k sampling.
    """

    def __init__(self, k, temperature=1.0, hard=False):
        super().__init__()
        self.k = k
        self.temperature = temperature
        self.hard = hard

    def forward(self, logits):
        return gumbel_top_k_sampling(
            logits, self.k, self.temperature, self.hard
        )
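

# Illustrative usage of the module wrapper (dimensions are placeholders):
#
#   sampler = GumbelTopKSampler(k=3, temperature=0.5, hard=True)
#   out = sampler(torch.randn(2, 10))  # (2, 10); each row has exactly 3 ones in the forward pass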


# Example integration with a simple model
class SimpleModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size, k):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, hidden_dim)
        self.output = torch.nn.Linear(hidden_dim, vocab_size)
        self.sampler = GumbelTopKSampler(k, temperature=0.5, hard=True)

    def forward(self, x):
        h = torch.relu(self.linear(x))
        logits = self.output(h)
        samples = self.sampler(logits)
        return samples, logits
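

# Illustrative sketch (an assumption, not part of the original code): running
# SimpleModel end to end and checking that a loss on the sampled output produces
# gradients for the model parameters. Dimensions and the loss weighting are placeholders.
def _example_simple_model():
    model = SimpleModel(input_dim=16, hidden_dim=32, vocab_size=10, k=3)
    x = torch.randn(4, 16)
    samples, logits = model(x)  # samples: (4, 10), k-hot in the forward pass
    # Weight positions differently; a plain .sum() would have a zero gradient
    # because the underlying softmax rows each sum to 1.
    loss = (samples * torch.arange(10, dtype=samples.dtype)).sum()
    loss.backward()
    return samples, logits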