import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from functools import lru_cache
from typing import List, Optional, Tuple
@lru_cache(maxsize=None)
def is_prime(n: int) -> bool:
"""Check if a number is prime (cached)."""
if n < 2:
return False
for i in range(2, int(n**0.5) + 1):
if n % i == 0:
return False
return True
@lru_cache(maxsize=None)
def get_primes_up_to(n: int) -> Tuple[int, ...]:
    """Get all primes up to and including n (cached; a tuple keeps the cached value immutable)."""
    return tuple(i for i in range(2, n + 1) if is_prime(i))
@lru_cache(maxsize=None)
def get_prime_set(n: int) -> frozenset:
    """Get primes up to n as a frozenset for O(1) membership tests."""
    return frozenset(get_primes_up_to(n))
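# Quick sanity examples (illustrative): get_primes_up_to(10) == (2, 3, 5, 7),
# is_prime(97) is True, and is_prime(1) is False.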
class PrimeSparsityLayer(nn.Module):
"""
A sparsity layer that enforces prime-based connection patterns.
This layer can operate in several modes:
1. Prime-only connections: Only allow connections at prime-indexed positions
2. Prime-factorized groups: Group neurons based on prime factorization
3. Prime-residue patterns: Use modular arithmetic with prime bases
"""
def __init__(self,
input_dim: int,
output_dim: int,
sparsity_mode: str = 'prime_only',
prime_base: Optional[int] = None,
sparsity_ratio: float = 0.5):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.sparsity_mode = sparsity_mode
self.sparsity_ratio = sparsity_ratio
# Initialize full weight matrix with proper scaling
self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
self.bias = nn.Parameter(torch.zeros(output_dim))
        # Generate prime-based mask
        if prime_base is None:
            # Default: the largest prime <= min(input_dim, output_dim), falling back to 2
            primes = get_primes_up_to(min(input_dim, output_dim))
            prime_base = primes[-1] if primes else 2
        self.prime_base = prime_base
        # Register the mask as a non-trainable buffer so it moves with the module's device/dtype
        self.register_buffer('sparsity_mask', self._generate_prime_mask())
def _generate_prime_mask(self) -> torch.Tensor:
"""Generate sparsity mask based on prime patterns."""
mask = torch.ones(self.output_dim, self.input_dim)
if self.sparsity_mode == 'prime_only':
            # Keep a connection when either its row or column index is prime
prime_set = get_prime_set(max(self.input_dim, self.output_dim))
mask.fill_(0)
for i in range(self.output_dim):
for j in range(self.input_dim):
if i in prime_set or j in prime_set:
mask[i, j] = 1
elif self.sparsity_mode == 'prime_residue':
            # Modular pattern: keep (i, j) when prime_base divides i * j. Since
            # prime_base is prime, this holds iff it divides i or j, giving an
            # expected density of roughly 2/p - 1/p**2 for p = prime_base.
mask.fill_(0)
for i in range(self.output_dim):
for j in range(self.input_dim):
if (i * j) % self.prime_base == 0:
mask[i, j] = 1
elif self.sparsity_mode == 'prime_factorized':
# Group connections based on prime factorization
mask.fill_(0)
def prime_factors(n):
factors = []
d = 2
while d * d <= n:
while n % d == 0:
factors.append(d)
n //= d
d += 1
if n > 1:
factors.append(n)
return factors
            # Precompute per-index factor sets (the +1 shift avoids factoring 0)
            out_factors = [set(prime_factors(i + 1)) for i in range(self.output_dim)]
            in_factors = [set(prime_factors(j + 1)) for j in range(self.input_dim)]
            for i in range(self.output_dim):
                for j in range(self.input_dim):
                    # Connect units whose indices share a prime factor
                    if out_factors[i] & in_factors[j]:
                        mask[i, j] = 1
        elif self.sparsity_mode == 'anti_prime':
            # Control condition: avoid prime-indexed rows and columns entirely
            prime_set = get_prime_set(max(self.input_dim, self.output_dim))
            for i in range(self.output_dim):
                for j in range(self.input_dim):
                    if i not in prime_set and j not in prime_set:
                        mask[i, j] = 1
                    else:
                        mask[i, j] = 0
        else:
            raise ValueError(f"Unknown sparsity_mode: {self.sparsity_mode!r}")
        # Trim to the target density by unbiased random pruning. To stay
        # consistent with get_sparsity_stats(), sparsity_ratio is the fraction
        # of weights pruned, so the target density is 1 - sparsity_ratio.
        total_possible = self.output_dim * self.input_dim
        current_connections = int(mask.sum().item())
        desired_connections = int(total_possible * (1.0 - self.sparsity_ratio))
        if current_connections > desired_connections:
            # Get 2D indices of active connections for unbiased sampling
            active_indices = torch.nonzero(mask, as_tuple=False)
            num_to_remove = current_connections - desired_connections
            # Random permutation for unbiased selection
            perm = torch.randperm(len(active_indices))[:num_to_remove]
            indices_to_remove = active_indices[perm]
            # Remove selected connections
            mask[indices_to_remove[:, 0], indices_to_remove[:, 1]] = 0
return mask
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with prime-based sparsity."""
# Apply sparsity mask to weights
sparse_weight = self.weight * self.sparsity_mask
return F.linear(x, sparse_weight, self.bias)
def get_sparsity_stats(self) -> dict:
"""Get statistics about the sparsity pattern."""
total_params = self.output_dim * self.input_dim
active_params = self.sparsity_mask.sum().item()
return {
'total_parameters': total_params,
'active_parameters': active_params,
'sparsity_ratio': 1 - (active_params / total_params),
'density': active_params / total_params,
'prime_base': self.prime_base,
'sparsity_mode': self.sparsity_mode,
'effective_connections': active_params
}
def visualize_mask(self, figsize=(8, 6)):
"""Create a visualization of the sparsity mask."""
try:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=figsize)
mask_np = self.sparsity_mask.detach().cpu().numpy()
im = ax.imshow(mask_np, cmap='Blues', aspect='auto')
ax.set_title(f'Prime Sparsity Pattern ({self.sparsity_mode})')
ax.set_xlabel('Input Dimension')
ax.set_ylabel('Output Dimension')
# Add colorbar
plt.colorbar(im, ax=ax)
# Add statistics as text
stats = self.get_sparsity_stats()
stats_text = f"Density: {stats['density']:.3f}\nActive: {stats['active_parameters']}"
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white'))
plt.tight_layout()
return fig
except ImportError:
print("Matplotlib not available for visualization")
return None
    def enforce_sparsity_gradients(self):
        """Enforce sparsity by zeroing gradients of masked weights."""
        if self.weight.grad is not None:
            # In-place mask on the gradient; avoids the deprecated .data idiom
            self.weight.grad.mul_(self.sparsity_mask)
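# Training-step pattern for PrimeSparsityLayer (a minimal sketch; `model`,
# `optimizer`, `criterion`, `x`, and `y` are assumed to exist):
#
#     loss = criterion(model(x), y)
#     optimizer.zero_grad()
#     loss.backward()
#     for m in model.modules():
#         if isinstance(m, PrimeSparsityLayer):
#             m.enforce_sparsity_gradients()  # re-zero gradients of masked weights
#     optimizer.step()
#
# Since forward() already multiplies the weights by the mask, backprop gives
# masked weights zero gradient anyway; the explicit call is a safeguard for
# anything else that writes into .grad (e.g. manual regularization terms).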
class PrimeExperimentHarness:
"""
Comprehensive experimental harness for testing prime-based sparsity hypotheses.
"""
def __init__(self, device='cpu'):
self.device = device
self.results = {}
def grokking_task_dataset(self, operation='add', modulus=97, size=1000):
"""Generate modular arithmetic dataset for grokking experiments."""
        # Sample `size` random operand pairs (uniform, with replacement)
a = torch.randint(0, modulus, (size,))
b = torch.randint(0, modulus, (size,))
if operation == 'add':
targets = (a + b) % modulus
elif operation == 'multiply':
targets = (a * b) % modulus
elif operation == 'subtract':
targets = (a - b) % modulus
else:
raise ValueError(f"Unknown operation: {operation}")
        # One-hot encode the operand pair: columns [0, modulus) encode a,
        # columns [modulus, 2*modulus) encode b
        rows = torch.arange(size)
        inputs = torch.zeros(size, 2 * modulus)
        inputs[rows, a] = 1
        inputs[rows, modulus + b] = 1
return inputs.to(self.device), targets.to(self.device)
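    # Encoding example (illustrative): with modulus m = 5, the pair a = 3,
    # b = 1 becomes the length-10 row [0,0,0,1,0, 0,1,0,0,0], and the 'add'
    # target is (3 + 1) % 5 = 4.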
    def run_width_experiment(self,
                             task_params: Optional[dict] = None,
                             width_configs=('prime', 'power_of_2', 'composite'),
                             epochs=5000,
                             lr=1e-3):
        """Compare different width configurations."""
        # Avoid a mutable default argument for the task specification
        if task_params is None:
            task_params = {'operation': 'add', 'modulus': 97, 'size': 2000}
        # Generate dataset and split 80/20 into train/test
        inputs, targets = self.grokking_task_dataset(**task_params)
        train_size = int(0.8 * len(inputs))
        train_inputs, test_inputs = inputs[:train_size], inputs[train_size:]
        train_targets, test_targets = targets[:train_size], targets[train_size:]
        width_mappings = {
            'prime': [127, 251, 509],
            'power_of_2': [128, 256, 512],
            'composite': [126, 252, 504]  # 2·3²·7, 2²·3²·7, 2³·3²·7
        }
results = {}
for config_name in width_configs:
print(f"\nTesting {config_name} width configuration...")
# Create model
model = PrimeWidthNetwork(
input_dim=inputs.shape[1],
num_classes=task_params['modulus'],
use_prime_widths=(config_name == 'prime'),
use_prime_sparsity=False, # Test width effect only
hidden_dims=width_mappings[config_name]
).to(self.device)
# Train and track metrics
train_losses, test_losses, test_accs = self._train_model(
model, train_inputs, train_targets, test_inputs, test_targets,
epochs=epochs, lr=lr
)
results[config_name] = {
'train_losses': train_losses,
'test_losses': test_losses,
'test_accuracies': test_accs,
'grokking_onset': self._detect_grokking_onset(test_accs),
'final_accuracy': test_accs[-1]
}
self.results['width_experiment'] = results
return results
    def run_sparsity_experiment(self,
                                task_params: Optional[dict] = None,
                                sparsity_modes=('prime_only', 'prime_residue', 'anti_prime', 'random'),
                                epochs=5000,
                                lr=1e-3):
        """Compare different sparsity patterns."""
        # Avoid a mutable default argument for the task specification
        if task_params is None:
            task_params = {'operation': 'add', 'modulus': 97, 'size': 2000}
        inputs, targets = self.grokking_task_dataset(**task_params)
        train_size = int(0.8 * len(inputs))
        train_inputs, test_inputs = inputs[:train_size], inputs[train_size:]
        train_targets, test_targets = targets[:train_size], targets[train_size:]
results = {}
for mode in sparsity_modes:
print(f"\nTesting {mode} sparsity mode...")
            if mode == 'random':
                # Baseline: dense layers with dropout (stochastic activation
                # masking, not structural weight sparsity)
model = nn.Sequential(
nn.Linear(inputs.shape[1], 127),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(127, 251),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(251, task_params['modulus'])
).to(self.device)
else:
# Prime sparsity network
model = nn.Sequential(
PrimeSparsityLayer(inputs.shape[1], 127, sparsity_mode=mode),
nn.ReLU(),
PrimeSparsityLayer(127, 251, sparsity_mode=mode),
nn.ReLU(),
nn.Linear(251, task_params['modulus'])
).to(self.device)
train_losses, test_losses, test_accs = self._train_model(
model, train_inputs, train_targets, test_inputs, test_targets,
epochs=epochs, lr=lr
)
results[mode] = {
'train_losses': train_losses,
'test_losses': test_losses,
'test_accuracies': test_accs,
'grokking_onset': self._detect_grokking_onset(test_accs),
'final_accuracy': test_accs[-1]
}
self.results['sparsity_experiment'] = results
return results
def _train_model(self, model, train_inputs, train_targets, test_inputs, test_targets,
epochs=5000, lr=1e-3, log_interval=500):
"""Train a model and track metrics."""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
train_losses = []
test_losses = []
test_accs = []
for epoch in range(epochs):
# Training
model.train()
optimizer.zero_grad()
outputs = model(train_inputs)
train_loss = criterion(outputs, train_targets)
train_loss.backward()
# Enforce sparsity in gradients if using prime sparsity layers
for module in model.modules():
if isinstance(module, PrimeSparsityLayer):
module.enforce_sparsity_gradients()
optimizer.step()
# Evaluation
if epoch % log_interval == 0 or epoch == epochs - 1:
model.eval()
with torch.no_grad():
test_outputs = model(test_inputs)
test_loss = criterion(test_outputs, test_targets)
test_acc = (test_outputs.argmax(dim=1) == test_targets).float().mean()
train_losses.append(train_loss.item())
test_losses.append(test_loss.item())
test_accs.append(test_acc.item())
if epoch % (log_interval * 10) == 0:
print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, "
f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
return train_losses, test_losses, test_accs
    def _detect_grokking_onset(self, test_accuracies, threshold=0.95, window=10):
        """Detect when grokking occurs: the first logged step where accuracy exceeds `threshold` and stays there for `window` evaluations."""
        test_accuracies = np.array(test_accuracies)
        # Include the final full window (hence the +1)
        for i in range(len(test_accuracies) - window + 1):
            if np.all(test_accuracies[i:i+window] > threshold):
                return i
        return len(test_accuracies)  # No grokking detected
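    # Note: the returned index counts logged evaluations, not raw epochs; with
    # log_interval=500, an onset index of 4 corresponds to roughly epoch 2000.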
def plot_results(self, experiment_type='width_experiment'):
"""Plot experimental results."""
try:
import matplotlib.pyplot as plt
if experiment_type not in self.results:
print(f"No results found for {experiment_type}")
return
results = self.results[experiment_type]
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# Plot 1: Training loss curves
for name, data in results.items():
axes[0, 0].plot(data['train_losses'], label=name)
axes[0, 0].set_title('Training Loss')
axes[0, 0].set_xlabel('Epoch (×log_interval)')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].set_yscale('log')
# Plot 2: Test accuracy curves
for name, data in results.items():
axes[0, 1].plot(data['test_accuracies'], label=name)
axes[0, 1].set_title('Test Accuracy')
axes[0, 1].set_xlabel('Epoch (×log_interval)')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
            # Plot 3: Grokking onset comparison
            names = list(results.keys())
            onset_times = [results[name]['grokking_onset'] for name in names]
            axes[1, 0].bar(names, onset_times)
            axes[1, 0].set_title('Grokking Onset Time')
            axes[1, 0].set_ylabel('Epochs (×log_interval)')
            # Plot 4: Final accuracy comparison
            final_accs = [results[name]['final_accuracy'] for name in names]
            axes[1, 1].bar(names, final_accs)
            axes[1, 1].set_title('Final Test Accuracy')
            axes[1, 1].set_ylabel('Accuracy')
plt.tight_layout()
return fig
except ImportError:
print("Matplotlib not available for plotting")
return None
"""
A test network with prime-based width choices and prime sparsity layers.
"""
    def __init__(self,
                 input_dim: int,
                 num_classes: int,
                 use_prime_widths: bool = True,
                 use_prime_sparsity: bool = True,
                 sparsity_mode: str = 'prime_only',
                 hidden_dims: Optional[List[int]] = None):
        super().__init__()
        # Choose layer widths; an explicit hidden_dims overrides the presets
        # (the experiment harness passes its own width configurations)
        if hidden_dims is None:
            if use_prime_widths:
                # Use prime numbers for hidden dimensions
                hidden_dims = [127, 251, 509]
            else:
                # Use powers of 2 (control condition)
                hidden_dims = [128, 256, 512]
layers = []
prev_dim = input_dim
        for dim in hidden_dims:
if use_prime_sparsity:
# Use prime sparsity layer
layer = PrimeSparsityLayer(prev_dim, dim, sparsity_mode=sparsity_mode)
else:
# Use standard linear layer
layer = nn.Linear(prev_dim, dim)
layers.extend([layer, nn.ReLU()])
prev_dim = dim
# Final layer
layers.append(nn.Linear(prev_dim, num_classes))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
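# Construction sketch (illustrative): for the default grokking task with
# modulus 97, the inputs are 2*97-dimensional one-hot pairs, so a matching
# network is PrimeWidthNetwork(input_dim=2 * 97, num_classes=97).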
# Example usage and testing
if __name__ == "__main__":
# Test the prime sparsity layer
layer = PrimeSparsityLayer(100, 127, sparsity_mode='prime_only')
print("Sparsity stats:", layer.get_sparsity_stats())
# Test forward pass
x = torch.randn(32, 100)
output = layer(x)
print(f"Input shape: {x.shape}, Output shape: {output.shape}")
# Test different sparsity modes
for mode in ['prime_only', 'prime_residue', 'prime_factorized', 'anti_prime']:
layer = PrimeSparsityLayer(50, 50, sparsity_mode=mode, sparsity_ratio=0.5)
stats = layer.get_sparsity_stats()
print(f"{mode}: {stats['active_parameters']}/{stats['total_parameters']} active parameters")