pytorch

Building and training neural networks with PyTorch. Use when implementing deep learning models, training loops, data pipelines, model optimization with torch.compile, distributed training, or deploying PyTorch models.

Install this agent skill to your Project

npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/development/pytorch

SKILL.md

Using PyTorch

PyTorch is a deep learning framework with dynamic computation graphs, strong GPU acceleration, and Pythonic design. This skill covers practical patterns for building production-quality neural networks.

Table of Contents

  • Core Concepts
  • Model Architecture
  • Training Loop
  • Data Loading
  • Performance Optimization
  • Distributed Training
  • Saving and Loading
  • Best Practices
  • References

Core Concepts

Tensors

python
import torch

# Create tensors
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
x = torch.zeros(3, 4)
x = torch.randn(3, 4)  # Normal distribution

# Device management
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)

# Operations (return new tensors; view shares the underlying storage)
y = x + 1
y = x @ x.T  # Matrix multiplication
y = x.view(2, 6)  # Reshape to 2x6 (x has 12 elements)

Autograd

python
# Enable gradient tracking
x = torch.randn(3, requires_grad=True)
y = x ** 2
loss = y.sum()

# Compute gradients
loss.backward()
print(x.grad)  # dy/dx

# Disable gradients for inference
with torch.no_grad():
    pred = model(x)

# Or use inference mode (more efficient)
with torch.inference_mode():
    pred = model(x)

Model Architecture

nn.Module Pattern

python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

Common Layers

python
# Convolution
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

# Normalization
nn.BatchNorm2d(num_features)
nn.LayerNorm(normalized_shape)

# Attention
nn.MultiheadAttention(embed_dim, num_heads)

# Recurrent
nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
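
A hedged sketch wiring a few of these layers together (channel and embedding sizes are illustrative placeholders):

python
import torch
import torch.nn as nn

# Illustrative convolutional block; 3 -> 64 channels is a placeholder choice
conv_block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Illustrative self-attention call: (batch, seq, embed) layout with batch_first=True
attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
x = torch.randn(4, 10, 256)
out, attn_weights = attn(x, x, x)  # out: (4, 10, 256)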

Weight Initialization

python
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=0.02)

model.apply(init_weights)

Training Loop

Standard Pattern

python
model = Model(input_dim, hidden_dim, output_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Optional: gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

    scheduler.step()

    # Validation
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            val_loss += criterion(model(inputs), targets).item()
        val_loss /= len(val_loader)

Mixed Precision Training

python
from torch.cuda.amp import autocast, GradScaler
# Recent PyTorch releases move these to torch.amp (e.g. torch.amp.autocast("cuda"));
# the torch.cuda.amp imports still work but emit deprecation warnings.

scaler = GradScaler()

for batch in train_loader:
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Gradient Accumulation

python
# Requires setup from Mixed Precision Training above:
# scaler = GradScaler(), model, criterion, optimizer, device

accumulation_steps = 4

for i, batch in enumerate(train_loader):
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Data Loading

Dataset and DataLoader

python
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.labels[idx]

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,  # Faster GPU transfer
    drop_last=True,   # Consistent batch sizes
)

Collate Functions

python
def collate_fn(batch):
    """Custom batching for variable-length sequences."""
    inputs, targets = zip(*batch)
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    targets = torch.stack(targets)
    return inputs, targets

loader = DataLoader(dataset, collate_fn=collate_fn)

Performance Optimization

torch.compile (PyTorch 2.0+)

python
# Basic compilation
model = torch.compile(model)

# With options
model = torch.compile(
    model,
    mode="reduce-overhead",  # Options: default, reduce-overhead, max-autotune
    fullgraph=True,          # Enforce no graph breaks
)

# Compile specific functions
@torch.compile
def train_step(model, inputs, targets):
    outputs = model(inputs)
    return criterion(outputs, targets)

Compilation modes:

  • default: Good balance of compile time and speedup
  • reduce-overhead: Minimizes framework overhead, good for small models
  • max-autotune: Maximum performance, longer compile time
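
To check whether a given mode actually pays off for your model, a quick eager-vs-compiled timing comparison is usually enough. A minimal sketch, assuming model and inputs already exist on a CUDA device (numbers vary by hardware):

python
import time
import torch

compiled = torch.compile(model, mode="reduce-overhead")

def bench(fn, x, iters=50):
    for _ in range(3):  # warm-up; the first compiled call triggers compilation
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print("eager:   ", bench(model, inputs))
print("compiled:", bench(compiled, inputs))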

Memory Optimization

python
# Activation checkpointing (trade compute for memory)
from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def forward(self, x):
        # Recompute activations during backward
        x = checkpoint(self.expensive_layer, x, use_reentrant=False)
        return self.output_layer(x)

# Clear cache
torch.cuda.empty_cache()

# Monitor memory
print(torch.cuda.memory_allocated() / 1e9, "GB")
print(torch.cuda.max_memory_allocated() / 1e9, "GB")

Distributed Training

DistributedDataParallel (DDP)

python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    model = Model().to(rank)
    model = DDP(model, device_ids=[rank])

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Important for shuffling
        # ... training loop

    cleanup()

# Launch with: torchrun --nproc_per_node=4 train.py
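
The train(rank, world_size) signature above matches a torch.multiprocessing.spawn launch. When launching with torchrun as in the comment, rank and world size are usually read from the environment variables torchrun sets for each process; a minimal sketch:

python
import os
import torch
import torch.distributed as dist

def setup_from_env():
    # torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE per process
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size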

FullyShardedDataParallel (FSDP)

python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    mixed_precision=mp_policy,
    use_orig_params=True,  # Required for torch.compile compatibility
)
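
By default FSDP treats the whole model as a single shard unit; an auto-wrap policy tells it which submodules to shard separately, which is where most of the memory savings come from. A hedged sketch using the size-based policy (the 1M-parameter threshold is an illustrative choice):

python
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any submodule with more than ~1M parameters in its own FSDP unit
wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)

model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=mp_policy,
    use_orig_params=True,
)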

Saving and Loading

Checkpoints

python
# Save
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
}, "checkpoint.pt")

# Load
checkpoint = torch.load("checkpoint.pt", map_location=device)  # consider weights_only=True for untrusted files
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

Export for Deployment

python
# TorchScript
scripted = torch.jit.script(model)
scripted.save("model.pt")

# ONNX
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
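
TorchScript is effectively in maintenance mode in recent releases; torch.export (PyTorch 2.1+) is the newer export path. A hedged sketch (API details may vary by version):

python
# Export to an ExportedProgram and serialize it
exported = torch.export.export(model, (dummy_input,))
torch.export.save(exported, "model.pt2")

# Later: reload and run
loaded = torch.export.load("model.pt2")
outputs = loaded.module()(dummy_input)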

Best Practices

  1. Always set model mode: Use model.train() and model.eval() appropriately
  2. Use inference_mode over no_grad: More efficient for inference
  3. Pin memory for GPU training: Set pin_memory=True in DataLoader
  4. Profile before optimizing: Use torch.profiler to find bottlenecks (see the sketch after this list)
  5. Prefer bfloat16 over float16: Better numerical stability on modern GPUs
  6. Use torch.compile: Significant speedups with minimal code changes
  7. Set deterministic mode for reproducibility:
    python
    torch.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
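For item 4, torch.profiler gives an operator-level breakdown of CPU and CUDA time. A minimal sketch, assuming model and inputs already live on the device:

python
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model(inputs)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))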

References

See reference/ for detailed documentation:

  • training-patterns.md - Advanced training techniques
  • debugging.md - Debugging and profiling tools
