"""Differentiable Gumbel top-k sampling."""

import torch
import torch.nn.functional as F


def gumbel_top_k_sampling(logits, k, temperature=1.0, hard=False, eps=1e-10):
    """
    Differentiable Gumbel top-k sampling using the Gumbel-Softmax trick.

    Args:
        logits (torch.Tensor): Input logits of shape (..., vocab_size)
        k (int): Number of top elements to sample
        temperature (float): Temperature for Gumbel-Softmax (lower = more discrete)
        hard (bool): If True, returns k-hot vectors; if False, returns soft samples
        eps (float): Small constant for numerical stability

    Returns:
        torch.Tensor: Sampled distribution of shape (..., vocab_size) where
            exactly k positions are "selected"
    """
    # Sample Gumbel noise
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)

    # Add Gumbel noise to the logits and apply the temperature
    gumbel_logits = (logits + gumbel_noise) / temperature

    # Get top-k indices of the perturbed logits
    _, top_k_indices = torch.topk(gumbel_logits, k, dim=-1)

    if hard:
        # Create a k-hot encoding of the top-k elements
        top_k_mask = torch.zeros_like(logits)
        top_k_mask.scatter_(-1, top_k_indices, 1.0)
        # Straight-through estimator: hard selection in the forward pass,
        # soft Gumbel-Softmax gradients in the backward pass
        gumbel_softmax = F.softmax(gumbel_logits, dim=-1)
        return top_k_mask - gumbel_softmax.detach() + gumbel_softmax
    else:
        # Soft sampling: mask out non-top-k elements and renormalize
        mask = torch.full_like(logits, float('-inf'))
        mask.scatter_(-1, top_k_indices, 0.0)
        masked_gumbel_logits = gumbel_logits + mask
        return F.softmax(masked_gumbel_logits, dim=-1)


def gumbel_top_k_sampling_v2(logits, k, temperature=1.0, eps=1e-10):
    """
    Alternative implementation using a continuous relaxation of the top-k operation.

    This version keeps nonzero gradients for every position by avoiding hard masking.

    Args:
        logits (torch.Tensor): Input logits of shape (..., vocab_size)
        k (int): Number of top elements to sample
        temperature (float): Temperature parameter
        eps (float): Small constant for numerical stability

    Returns:
        torch.Tensor: Soft top-k samples
    """
    # Sample Gumbel noise and perturb the logits
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    gumbel_logits = logits + gumbel_noise

    # Continuous relaxation of top-k:
    # sort the perturbed logits to find the k-th largest value
    sorted_gumbel, _ = torch.sort(gumbel_logits, dim=-1, descending=True)
    threshold = sorted_gumbel[..., k - 1:k]  # k-th largest value

    # Soft membership mask: sigmoid of the distance to the threshold
    soft_mask = torch.sigmoid((gumbel_logits - threshold) / temperature)

    # Apply the soft mask in log space and renormalize. (Multiplying the raw
    # logits by the mask would pull negative logits toward 0 and could
    # *increase* their probability, so the mask is added as log-weights.)
    masked_logits = logits + torch.log(soft_mask + eps)
    return F.softmax(masked_logits / temperature, dim=-1)
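
# --- Illustrative sketch (added for context; not part of the original code) ---
# The perturbation used above is the Gumbel-top-k trick: taking the top-k of
# logits + Gumbel noise is equivalent to sampling k items *without replacement*
# from the categorical distribution softmax(logits). The helper below is a
# minimal, non-differentiable reference sampler built on that fact; the name
# `gumbel_top_k_indices` is introduced here purely for illustration.
def gumbel_top_k_indices(logits, k, eps=1e-10):
    """Return indices of k items sampled without replacement (no gradients)."""
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    _, indices = torch.topk(logits + gumbel_noise, k, dim=-1)
    return indices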
class GumbelTopKSampler(torch.nn.Module):
    """PyTorch module wrapper for Gumbel top-k sampling."""

    def __init__(self, k, temperature=1.0, hard=False):
        super().__init__()
        self.k = k
        self.temperature = temperature
        self.hard = hard

    def forward(self, logits):
        return gumbel_top_k_sampling(logits, self.k, self.temperature, self.hard)


# Example integration with a simple model
class SimpleModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size, k):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, hidden_dim)
        self.output = torch.nn.Linear(hidden_dim, vocab_size)
        self.sampler = GumbelTopKSampler(k, temperature=0.5, hard=True)

    def forward(self, x):
        h = torch.relu(self.linear(x))
        logits = self.output(h)
        samples = self.sampler(logits)
        return samples, logits


# Example usage and testing
if __name__ == "__main__":
    torch.manual_seed(42)

    # Example logits
    batch_size, seq_len, vocab_size = 2, 3, 10
    logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
    k = 3

    print("Original logits shape:", logits.shape)
    print("Sampling top-{} elements".format(k))

    # Test hard sampling
    hard_samples = gumbel_top_k_sampling(logits, k, temperature=0.1, hard=True)
    print("\nHard samples (should have exactly {} non-zero elements per row):".format(k))
    print("Non-zero counts per sample:", (hard_samples > 0.01).sum(dim=-1))

    # Test soft sampling
    soft_samples = gumbel_top_k_sampling(logits, k, temperature=1.0, hard=False)
    print("\nSoft samples (continuous relaxation):")
    print("Sample sums (should be close to 1.0):", soft_samples.sum(dim=-1))

    # Test alternative implementation
    soft_samples_v2 = gumbel_top_k_sampling_v2(logits, k, temperature=1.0)
    print("\nAlternative soft samples:")
    print("Sample sums:", soft_samples_v2.sum(dim=-1))

    # Test gradient flow through the straight-through estimator.
    # (A plain sum has an analytically zero gradient, because the soft part is
    # a softmax that always sums to 1, so project onto random weights instead.)
    weights = torch.randn(vocab_size)
    loss = (hard_samples * weights).sum()
    loss.backward()
    print("\nGradient flow test - logits.grad is not None:", logits.grad is not None)
    print("Gradient norm:", logits.grad.norm().item())
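

# --- Usage sketch (added for illustration; not part of the original code) ---
# A minimal end-to-end check that gradients reach the encoder weights through
# the hard (straight-through) sampler: the toy objective rewards placing one of
# the k selected positions on a fixed target token. The function name
# `_train_sketch`, the dimensions, and the optimizer choice are assumptions
# made only for this example.
def _train_sketch(steps=50, input_dim=8, hidden_dim=16, vocab_size=10, k=3):
    torch.manual_seed(0)
    model = SimpleModel(input_dim, hidden_dim, vocab_size, k)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    x = torch.randn(32, input_dim)   # toy batch of inputs
    target = torch.zeros(vocab_size)
    target[0] = 1.0                  # reward selecting token 0

    for _ in range(steps):
        samples, logits = model(x)   # samples: k-hot forward, soft backward
        loss = -(samples * target).sum(dim=-1).mean()
        optimizer.zero_grad()
        loss.backward()              # gradients flow via the Gumbel-Softmax path
        optimizer.step()

    return loss.item()               # e.g. call as: final_loss = _train_sketch()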