PyTorch: The Training Loop
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Define model
class MLP(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden, out_dim)
        )

    def forward(self, x):
        return self.net(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MLP(784, 256, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
criterion = nn.CrossEntropyLoss()

# Placeholder data so the loop runs end to end (swap in a real dataset)
num_epochs = 50
train_loader = DataLoader(
    TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,))),
    batch_size=64, shuffle=True
)

# Training loop (canonical pattern)
for epoch in range(num_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        pred = model(X_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
        optimizer.step()
    scheduler.step()  # once per epoch (T_max counts epochs)

    model.eval()
    with torch.no_grad():  # disable gradient computation
        # compute val loss and accuracy here
        pass
Common Bugs: Forgetting model.eval() before inference (dropout stays active, so predictions become stochastic!); forgetting optimizer.zero_grad() (gradients accumulate across batches!); calling .detach() on a tensor that still needs gradients (silently cuts the graph and breaks backprop). Keep the order: zero_grad → forward → backward → step.
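A quick sketch of why the first bug matters: nn.Dropout zeroes activations (and rescales the rest by 1/(1-p)) only in training mode, so an un-eval'd model gives different outputs on every forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(8)

layer.train()
out_train = layer(x)  # some elements zeroed, survivors scaled to 2.0

layer.eval()
out_eval = layer(x)   # identity: dropout disabled

print(out_train)
print(out_eval)       # tensor([1., 1., 1., 1., 1., 1., 1., 1.])
```

The same switch controls BatchNorm's choice between batch statistics and running statistics, which is another common source of train/eval mismatch.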
Attention Mechanism & Transformers
Scaled Dot-Product Attention
Attention(Q,K,V) = softmax(QKᵀ/√d_k)·V
- Q (Query): What am I looking for?
- K (Key): What do I contain?
- V (Value): What do I return?
- Dividing by √d_k keeps the dot products' variance near 1, preventing softmax saturation (vanishing gradients)
- O(n²) memory — quadratic in sequence length
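The formula above translates almost line for line into PyTorch. A minimal sketch, assuming inputs of shape (batch, seq_len, d_k) and omitting masking for brevity:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.size(-1)
    # (batch, n, n) score matrix -- this is the O(n^2) term
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V, weights

Q = torch.randn(2, 5, 64)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 64)
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 5, 64])
```

Recent PyTorch versions also ship this as torch.nn.functional.scaled_dot_product_attention, which uses memory-efficient kernels to soften the O(n²) cost.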
Multi-Head Attention
MultiHead(Q,K,V) = Concat(head₁, …, headₕ)·W^O
headᵢ = Attention(Q·Wᵢ^Q, K·Wᵢ^K, V·Wᵢ^V)
Different heads learn different types of relationships: syntax, coreference, semantics. Typically h = 8 or 16, with per-head dimension d_k = d_model / h.
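The Concat/split bookkeeping is the only non-obvious part. A sketch of the standard pattern, assuming d_model is divisible by h (here the per-head projections Wᵢ^Q are fused into one d_model × d_model linear layer):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        B, n, _ = Q.shape
        # project, then reshape to (B, h, n, d_k) so each head attends independently
        def split(x):
            return x.view(B, -1, self.h, self.d_k).transpose(1, 2)
        q, k, v = split(self.W_q(Q)), split(self.W_k(K)), split(self.W_v(V))
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        heads = torch.softmax(scores, dim=-1) @ v         # (B, h, n, d_k)
        concat = heads.transpose(1, 2).reshape(B, n, -1)  # Concat(head_1..head_h)
        return self.W_o(concat)                           # final W^O projection

mha = MultiHeadAttention(d_model=64, h=8)
x = torch.randn(2, 10, 64)
print(mha(x, x, x).shape)  # torch.Size([2, 10, 64])
```

In practice nn.MultiheadAttention provides the same computation with masking and dropout built in.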
Transformer Block
Input → LayerNorm → Multi-Head Attention → Residual → LayerNorm → FFN → Residual → Output
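The diagram above is the pre-LN variant (LayerNorm before each sublayer, residual added after). A minimal sketch using nn.MultiheadAttention; the sizes are illustrative:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model=64, h=8, d_ff=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, h, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        a = self.ln1(x)                                    # LayerNorm, then attention
        x = x + self.attn(a, a, a, need_weights=False)[0]  # residual around attention
        x = x + self.ffn(self.ln2(x))                      # LayerNorm -> FFN -> residual
        return x

block = TransformerBlock()
x = torch.randn(2, 10, 64)
print(block(x).shape)  # torch.Size([2, 10, 64])
```

Pre-LN (as here) trains more stably than the original post-LN ordering, which applied LayerNorm after each residual sum instead.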
Why Transformers replaced RNNs: RNNs process tokens sequentially, so training can't be parallelized across the sequence, and their gradients vanish over long spans. Transformers attend to all positions simultaneously (fully parallelizable on GPUs) and give every position a direct, one-hop connection to every other position via attention.