Content is user-generated and unverified.
import math
from functools import lru_cache
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


@lru_cache(maxsize=None)
def is_prime(n: int) -> bool:
    """Return True if ``n`` is prime (memoized trial division)."""
    if n < 2:
        return False
    # math.isqrt avoids the float-rounding hazard of int(n ** 0.5) for large n.
    for d in range(2, math.isqrt(n) + 1):
        if n % d == 0:
            return False
    return True


@lru_cache(maxsize=None)
def get_primes_up_to(n: int) -> List[int]:
    """Return all primes <= n in ascending order (memoized).

    NOTE: the cached list is shared between callers -- do not mutate it.
    """
    return [i for i in range(2, n + 1) if is_prime(i)]


@lru_cache(maxsize=None)
def get_prime_set(n: int) -> set:
    """Return primes <= n as a set for O(1) membership tests (memoized)."""
    return set(get_primes_up_to(n))


class PrimeSparsityLayer(nn.Module):
    """Linear layer whose weight matrix is masked by a prime-based pattern.

    Supported ``sparsity_mode`` values:

    * ``'prime_only'``       -- keep weight (i, j) iff i or j is prime.
    * ``'prime_residue'``    -- keep (i, j) iff (i * j) % prime_base == 0.
    * ``'prime_factorized'`` -- keep (i, j) iff i+1 and j+1 share a prime factor.
    * ``'anti_prime'``       -- keep (i, j) iff neither i nor j is prime
      (control condition).

    Any other mode yields a fully dense pattern.  After the pattern is built,
    active connections are randomly pruned (if needed) so that at most
    ``sparsity_ratio`` of the full weight matrix stays active.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 sparsity_mode: str = 'prime_only',
                 prime_base: Optional[int] = None,
                 sparsity_ratio: float = 0.5):
        """
        Args:
            input_dim: number of input features.
            output_dim: number of output features.
            sparsity_mode: mask pattern; see class docstring.
            prime_base: modulus used by ``'prime_residue'``.  Defaults to the
                largest prime <= min(input_dim, output_dim), or 2 if no such
                prime exists.
            sparsity_ratio: maximum fraction of weights left active.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.sparsity_mode = sparsity_mode
        self.sparsity_ratio = sparsity_ratio

        # Dense weight; sparsity is enforced multiplicatively in forward().
        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.bias = nn.Parameter(torch.zeros(output_dim))

        if prime_base is None:
            # get_primes_up_to returns primes ascending, so the last entry is
            # the largest prime below the smaller dimension.
            primes = get_primes_up_to(min(input_dim, output_dim))
            prime_base = primes[-1] if primes else 2
        self.prime_base = prime_base

        # Registered as a buffer so the mask follows .to(device)/state_dict()
        # but is never trained; self.mask kept as an alias for callers.
        self.mask = self._generate_prime_mask()
        self.register_buffer('sparsity_mask', self.mask)

    def _generate_prime_mask(self) -> torch.Tensor:
        """Build the 0/1 mask for the configured mode, then prune to ratio."""
        return self._prune_to_ratio(self._pattern_mask())

    def _pattern_mask(self) -> torch.Tensor:
        """Return the raw 0/1 pattern for ``sparsity_mode`` (before pruning)."""
        mask = torch.zeros(self.output_dim, self.input_dim)
        if self.sparsity_mode == 'prime_only':
            # Only allow connections touching a prime-indexed row or column.
            primes = get_prime_set(max(self.input_dim, self.output_dim))
            for i in range(self.output_dim):
                for j in range(self.input_dim):
                    if i in primes or j in primes:
                        mask[i, j] = 1
        elif self.sparsity_mode == 'prime_residue':
            # Modular-arithmetic pattern with a prime base.
            for i in range(self.output_dim):
                for j in range(self.input_dim):
                    if (i * j) % self.prime_base == 0:
                        mask[i, j] = 1
        elif self.sparsity_mode == 'prime_factorized':
            # Connect indices whose (1-shifted) values share a prime factor.
            for i in range(self.output_dim):
                for j in range(self.input_dim):
                    # +1 shifts off index 0, which has no prime factorization.
                    if set(self._prime_factors(i + 1)) & set(self._prime_factors(j + 1)):
                        mask[i, j] = 1
        elif self.sparsity_mode == 'anti_prime':
            # Control condition: avoid prime-indexed rows and columns.
            primes = get_prime_set(max(self.input_dim, self.output_dim))
            for i in range(self.output_dim):
                for j in range(self.input_dim):
                    if i not in primes and j not in primes:
                        mask[i, j] = 1
        else:
            # Unknown mode: original code fell through to a dense pattern
            # (still subject to ratio pruning); preserved for compatibility.
            mask.fill_(1)
        return mask

    @staticmethod
    def _prime_factors(n: int) -> List[int]:
        """Trial-division prime factorization of ``n`` (with multiplicity)."""
        factors = []
        d = 2
        while d * d <= n:
            while n % d == 0:
                factors.append(d)
                n //= d
            d += 1
        if n > 1:
            factors.append(n)
        return factors

    def _prune_to_ratio(self, mask: torch.Tensor) -> torch.Tensor:
        """Randomly drop active connections (uniformly, unbiased) until at
        most ``sparsity_ratio`` of the full matrix remains active."""
        total_possible = self.output_dim * self.input_dim
        # int(): .item() on a float tensor yields a Python float, which would
        # make the randperm slice below raise TypeError.
        current = int(mask.sum().item())
        desired = int(total_possible * self.sparsity_ratio)
        if current > desired:
            active_indices = torch.nonzero(mask, as_tuple=False)
            num_to_remove = current - desired
            # Random permutation gives an unbiased selection of connections.
            chosen = torch.randperm(active_indices.size(0))[:num_to_remove]
            doomed = active_indices[chosen]
            mask[doomed[:, 0], doomed[:, 1]] = 0
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Masked linear transform: ``x @ (W * mask).T + b``."""
        sparse_weight = self.weight * self.sparsity_mask
        return F.linear(x, sparse_weight, self.bias)

    def get_sparsity_stats(self) -> dict:
        """Return a summary dict describing the current sparsity pattern."""
        total_params = self.output_dim * self.input_dim
        active_params = int(self.sparsity_mask.sum().item())
        return {
            'total_parameters': total_params,
            'active_parameters': active_params,
            'sparsity_ratio': 1 - (active_params / total_params),
            'density': active_params / total_params,
            'prime_base': self.prime_base,
            'sparsity_mode': self.sparsity_mode,
            'effective_connections': active_params,
        }

    def visualize_mask(self, figsize=(8, 6)):
        """Render the sparsity mask as a heatmap.

        Returns:
            The matplotlib Figure, or None when matplotlib is unavailable.
        """
        # Keep the try body minimal: only the import can raise ImportError.
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("Matplotlib not available for visualization")
            return None
        fig, ax = plt.subplots(figsize=figsize)
        mask_np = self.sparsity_mask.detach().cpu().numpy()
        im = ax.imshow(mask_np, cmap='Blues', aspect='auto')
        ax.set_title(f'Prime Sparsity Pattern ({self.sparsity_mode})')
        ax.set_xlabel('Input Dimension')
        ax.set_ylabel('Output Dimension')
        plt.colorbar(im, ax=ax)
        # Overlay density/active-count statistics in the top-left corner.
        stats = self.get_sparsity_stats()
        stats_text = f"Density: {stats['density']:.3f}\nActive: {stats['active_parameters']}"
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white'))
        plt.tight_layout()
        return fig

    def enforce_sparsity_gradients(self):
        """Zero the gradients of masked-out weights (call after backward())
        so pruned connections never receive updates."""
        if self.weight.grad is not None:
            self.weight.grad.mul_(self.sparsity_mask)
class PrimeExperimentHarness:
    """
    Comprehensive experimental harness for testing prime-based sparsity
    hypotheses on modular-arithmetic (grokking) tasks.
    """

    def __init__(self, device='cpu'):
        """
        Args:
            device: torch device string used for datasets and models.
        """
        self.device = device
        self.results = {}  # experiment name -> per-configuration metrics

    def grokking_task_dataset(self, operation='add', modulus=97, size=1000):
        """Generate a modular-arithmetic dataset for grokking experiments.

        Inputs are the concatenated one-hot encodings of the two operands,
        shape ``(size, 2 * modulus)``; targets are ``(a op b) % modulus``.

        Raises:
            ValueError: for an unknown ``operation``.
        """
        a = torch.randint(0, modulus, (size,))
        b = torch.randint(0, modulus, (size,))
        if operation == 'add':
            targets = (a + b) % modulus
        elif operation == 'multiply':
            targets = (a * b) % modulus
        elif operation == 'subtract':
            targets = (a - b) % modulus
        else:
            raise ValueError(f"Unknown operation: {operation}")
        # One-hot: first ``modulus`` slots encode a, the next encode b.
        inputs = torch.zeros(size, 2 * modulus)
        inputs[range(size), a] = 1
        inputs[range(size), modulus + b] = 1
        return inputs.to(self.device), targets.to(self.device)

    def run_width_experiment(self, task_params=None, width_configs=None,
                             epochs=5000, lr=1e-3):
        """Compare prime vs power-of-2 vs composite hidden-layer widths.

        Args:
            task_params: kwargs for grokking_task_dataset; defaults to
                ``{'operation': 'add', 'modulus': 97, 'size': 2000}``.
            width_configs: subset of ``['prime', 'power_of_2', 'composite']``.
            epochs: training epochs per configuration.
            lr: Adam learning rate.
        """
        # None-defaults avoid the mutable-default-argument pitfall.
        if task_params is None:
            task_params = {'operation': 'add', 'modulus': 97, 'size': 2000}
        if width_configs is None:
            width_configs = ['prime', 'power_of_2', 'composite']

        inputs, targets = self.grokking_task_dataset(**task_params)
        train_size = int(0.8 * len(inputs))
        train_inputs, test_inputs = inputs[:train_size], inputs[train_size:]
        train_targets, test_targets = targets[:train_size], targets[train_size:]

        width_mappings = {
            'prime': [127, 251, 509],
            'power_of_2': [128, 256, 512],
            'composite': [126, 252, 504],  # 2·3²·7, 2²·3²·7, 2³·3²·7
        }

        results = {}
        for config_name in width_configs:
            print(f"\nTesting {config_name} width configuration...")
            model = PrimeWidthNetwork(
                input_dim=inputs.shape[1],
                num_classes=task_params['modulus'],
                use_prime_widths=(config_name == 'prime'),
                use_prime_sparsity=False,  # isolate the width effect
                hidden_dims=width_mappings[config_name],
            ).to(self.device)
            train_losses, test_losses, test_accs = self._train_model(
                model, train_inputs, train_targets, test_inputs, test_targets,
                epochs=epochs, lr=lr)
            results[config_name] = {
                'train_losses': train_losses,
                'test_losses': test_losses,
                'test_accuracies': test_accs,
                'grokking_onset': self._detect_grokking_onset(test_accs),
                'final_accuracy': test_accs[-1],
            }
        self.results['width_experiment'] = results
        return results

    def run_sparsity_experiment(self, task_params=None, sparsity_modes=None,
                                epochs=5000, lr=1e-3):
        """Compare prime-based sparsity patterns against a dropout baseline.

        Args:
            task_params: kwargs for grokking_task_dataset; defaults to
                ``{'operation': 'add', 'modulus': 97, 'size': 2000}``.
            sparsity_modes: modes to test; ``'random'`` uses a plain
                dropout network as the control.
        """
        if task_params is None:
            task_params = {'operation': 'add', 'modulus': 97, 'size': 2000}
        if sparsity_modes is None:
            sparsity_modes = ['prime_only', 'prime_residue', 'anti_prime', 'random']

        inputs, targets = self.grokking_task_dataset(**task_params)
        train_size = int(0.8 * len(inputs))
        train_inputs, test_inputs = inputs[:train_size], inputs[train_size:]
        train_targets, test_targets = targets[:train_size], targets[train_size:]

        results = {}
        for mode in sparsity_modes:
            print(f"\nTesting {mode} sparsity mode...")
            if mode == 'random':
                # Control: standard dense network with dropout.
                model = nn.Sequential(
                    nn.Linear(inputs.shape[1], 127), nn.ReLU(), nn.Dropout(0.5),
                    nn.Linear(127, 251), nn.ReLU(), nn.Dropout(0.5),
                    nn.Linear(251, task_params['modulus']),
                ).to(self.device)
            else:
                # Prime-sparsity network with the mode under test.
                model = nn.Sequential(
                    PrimeSparsityLayer(inputs.shape[1], 127, sparsity_mode=mode),
                    nn.ReLU(),
                    PrimeSparsityLayer(127, 251, sparsity_mode=mode),
                    nn.ReLU(),
                    nn.Linear(251, task_params['modulus']),
                ).to(self.device)
            train_losses, test_losses, test_accs = self._train_model(
                model, train_inputs, train_targets, test_inputs, test_targets,
                epochs=epochs, lr=lr)
            results[mode] = {
                'train_losses': train_losses,
                'test_losses': test_losses,
                'test_accuracies': test_accs,
                'grokking_onset': self._detect_grokking_onset(test_accs),
                'final_accuracy': test_accs[-1],
            }
        self.results['sparsity_experiment'] = results
        return results

    def _train_model(self, model, train_inputs, train_targets,
                     test_inputs, test_targets,
                     epochs=5000, lr=1e-3, log_interval=500):
        """Train ``model`` with Adam + cross-entropy; return metric histories.

        Metrics are recorded every ``log_interval`` epochs (and at the final
        epoch), so the returned lists have one entry per logging point.
        """
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        train_losses, test_losses, test_accs = [], [], []
        for epoch in range(epochs):
            model.train()
            optimizer.zero_grad()
            outputs = model(train_inputs)
            train_loss = criterion(outputs, train_targets)
            train_loss.backward()
            # Keep pruned connections frozen in any prime-sparsity layers.
            for module in model.modules():
                if isinstance(module, PrimeSparsityLayer):
                    module.enforce_sparsity_gradients()
            optimizer.step()
            if epoch % log_interval == 0 or epoch == epochs - 1:
                model.eval()
                with torch.no_grad():
                    test_outputs = model(test_inputs)
                    test_loss = criterion(test_outputs, test_targets)
                    test_acc = (test_outputs.argmax(dim=1) == test_targets).float().mean()
                train_losses.append(train_loss.item())
                test_losses.append(test_loss.item())
                test_accs.append(test_acc.item())
                if epoch % (log_interval * 10) == 0:
                    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, "
                          f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
        return train_losses, test_losses, test_accs

    def _detect_grokking_onset(self, test_accuracies, threshold=0.95, window=10):
        """Return the first index from which accuracy stays above ``threshold``
        for ``window`` consecutive measurements.

        Returns ``len(test_accuracies)`` when no such window exists.
        """
        accs = np.array(test_accuracies)
        # +1 so a qualifying window ending exactly at the last sample counts.
        for i in range(len(accs) - window + 1):
            if np.all(accs[i:i + window] > threshold):
                return i
        return len(accs)  # no grokking detected

    def plot_results(self, experiment_type='width_experiment'):
        """Plot loss/accuracy curves and onset/final-accuracy bars for a
        stored experiment.  Returns the Figure, or None when matplotlib is
        unavailable or no results exist."""
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("Matplotlib not available for plotting")
            return None
        if experiment_type not in self.results:
            print(f"No results found for {experiment_type}")
            return
        results = self.results[experiment_type]
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        # Plot 1: training loss curves (log scale).
        for name, data in results.items():
            axes[0, 0].plot(data['train_losses'], label=name)
        axes[0, 0].set_title('Training Loss')
        axes[0, 0].set_xlabel('Epoch (×log_interval)')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].set_yscale('log')
        # Plot 2: test accuracy curves.
        for name, data in results.items():
            axes[0, 1].plot(data['test_accuracies'], label=name)
        axes[0, 1].set_title('Test Accuracy')
        axes[0, 1].set_xlabel('Epoch (×log_interval)')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        # Plot 3: grokking onset comparison.
        onset_times = [results[name]['grokking_onset'] for name in results.keys()]
        axes[1, 0].bar(results.keys(), onset_times)
        axes[1, 0].set_title('Grokking Onset Time')
        axes[1, 0].set_ylabel('Epochs (×log_interval)')
        # Plot 4: final accuracy comparison.
        final_accs = [results[name]['final_accuracy'] for name in results.keys()]
        axes[1, 1].bar(results.keys(), final_accs)
        axes[1, 1].set_title('Final Test Accuracy')
        axes[1, 1].set_ylabel('Accuracy')
        plt.tight_layout()
        return fig


class PrimeWidthNetwork(nn.Module):
    """
    A test network with prime-based width choices and prime sparsity layers.
    """

    # NOTE(review): the original paste was missing the ``class`` header above,
    # leaving __init__/forward dangling while run_width_experiment referenced
    # the undefined name PrimeWidthNetwork; restored here.

    def __init__(self, input_dim: int, num_classes: int,
                 use_prime_widths: bool = True,
                 use_prime_sparsity: bool = True,
                 sparsity_mode: str = 'prime_only',
                 hidden_dims: Optional[List[int]] = None):
        """
        Args:
            input_dim: number of input features.
            num_classes: size of the output layer.
            use_prime_widths: choose prime hidden widths (vs powers of 2)
                when ``hidden_dims`` is not given.
            use_prime_sparsity: use PrimeSparsityLayer instead of nn.Linear.
            sparsity_mode: passed through to PrimeSparsityLayer.
            hidden_dims: explicit hidden widths; overrides ``use_prime_widths``.
                Accepted because run_width_experiment passes this keyword.
        """
        super().__init__()
        if hidden_dims is None:
            # Prime widths for the hypothesis, powers of 2 as the control.
            hidden_dims = [127, 251, 509] if use_prime_widths else [128, 256, 512]
        layers = []
        prev_dim = input_dim
        for dim in hidden_dims:
            if use_prime_sparsity:
                layer = PrimeSparsityLayer(prev_dim, dim, sparsity_mode=sparsity_mode)
            else:
                layer = nn.Linear(prev_dim, dim)
            layers.extend([layer, nn.ReLU()])
            prev_dim = dim
        # Final classification layer.
        layers.append(nn.Linear(prev_dim, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


# Example usage and testing
if __name__ == "__main__":
    # Test the prime sparsity layer
    layer = PrimeSparsityLayer(100, 127, sparsity_mode='prime_only')
    print("Sparsity stats:", layer.get_sparsity_stats())

    # Test forward pass
    x = torch.randn(32, 100)
    output = layer(x)
    print(f"Input shape: {x.shape}, Output shape: {output.shape}")

    # Test different sparsity modes
    for mode in ['prime_only', 'prime_residue', 'prime_factorized', 'anti_prime']:
        layer = PrimeSparsityLayer(50, 50, sparsity_mode=mode, sparsity_ratio=0.5)
        stats = layer.get_sparsity_stats()
        print(f"{mode}: {stats['active_parameters']}/{stats['total_parameters']} active parameters")
Content is user-generated and unverified.
    Prime-Aware Sparsity Layer | Claude