A complete guide to implementing robust, checkpoint-based training resumption with reproducible random states for PyTorch models.
- Overview
- Why This Matters
- SLURM Scripts
- Quick Start
- Core Components
- Troubleshooting
- Advanced Topics
- Summary
This training framework provides:
- ✅ Automatic checkpoint saving after each epoch
- ✅ Reproducible random states for consistent data shuffling
- ✅ Config-based checkpoint hashing to prevent loading incompatible checkpoints
- ✅ Graceful interruption handling for SLURM/cluster time limits
- ✅ A simple API that’s easy to integrate into existing code
If training is interrupted mid-epoch and later resumed, the epoch restarts with the exact same data shuffle, producing identical batch losses. This makes debugging easier and training more predictable.
Without random state saving:

```
First run: Epoch 2 [0/1000] Loss: 1.993
[Job interrupted]
Second run: Epoch 2 [0/1000] Loss: 1.824 ❌ DIFFERENT!
```
With random state saving:

```
First run: Epoch 2 [0/1000] Loss: 1.993
[Job interrupted]
Second run: Epoch 2 [0/1000] Loss: 1.993 ✅ SAME!
```
Why? The DataLoader with shuffle=True creates a new random permutation each epoch. When you restore random states, PyTorch generates the exact same shuffle, so batch 0 contains the same samples → same loss.
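This is easy to verify in isolation. Below is a minimal, standalone sketch (the toy dataset and `first_batch` helper are purely illustrative, not part of the framework) showing that restoring the RNG states reproduces the DataLoader's shuffle:

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset; shuffle=True draws a fresh permutation every time we iterate.
dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

def first_batch(dl):
    """Return the sample values of batch 0 for a fresh pass over the loader."""
    return next(iter(dl))[0].tolist()

# Capture the RNG states (as a checkpoint would) and draw batch 0 once...
saved = {
    'python': random.getstate(),
    'numpy': np.random.get_state(),
    'torch': torch.get_rng_state(),
}
before = first_batch(loader)

# ...then restore the states (as resumption would) and draw batch 0 again.
random.setstate(saved['python'])
np.random.set_state(saved['numpy'])
torch.set_rng_state(saved['torch'])
after = first_batch(loader)

assert before == after  # identical samples → identical batch-0 loss
```

The DataLoader's internal sampler is seeded from PyTorch's global RNG each time iteration starts, which is why restoring that state is enough to reproduce the shuffle.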
Two scripts are provided for running training on SLURM clusters (e.g., Purdue Gilbreth). The master script resume_training.sh submits NUM_RUNS chained jobs, each invoking train_experiment.sh to continue training from the latest checkpoint.
Submits a chain of dependent SLURM jobs that automatically resume training.
Key Features:
- Sequential execution: starts the next job only after the previous one finishes or times out (`--dependency=afterany`).
- Failure tolerance: any failure (including timeouts) does not stop the chain; the next job still starts.
- Email notification: sends an email when the final job completes.
- Logging: appends stdout to `logs/training_run.out` and stderr to `logs/training_run.err`.
Configuration to Modify:
```bash
NUM_RUNS=10 # CHANGE: Number of jobs to chain
# Change email for the final job
EMAIL_FLAGS="--mail-type=END --mail-user=rjaiswa@purdue.edu" # CHANGE
# Customize job name prefix
--job-name="cifar_run_${i}" # CHANGE: Prefix as needed
```

Usage:
```bash
./resume_training.sh
```

How It Works:
- Submits job 1.
- Submits job 2 with `--dependency=afterany:JOB1_ID` so it runs after job 1 regardless of exit status.
- Repeats for `NUM_RUNS` jobs.
- The final job sends an email notification.
Monitoring:
```bash
# Watch job queue
watch -n 2 squeue -u $USER

# View live training progress
tail -f logs/training_run.out
```

Example Output:
```
Submitted run 1/10 as job 12345
Submitted run 2/10 as job 12346
Submitted run 3/10 as job 12347
...
All 10 jobs submitted. Last job ID: 12354
```
You'll receive an email when job 12354 completes.
Runs a single training session with automatic checkpoint resumption.
Key Features:
- Timeout handling: distinguishes timeouts from real failures (not foolproof).
- Email notifications only on real failures, not on timeouts (not foolproof).
- Logging: appends to `logs/training_run.out` and `logs/training_run.err`.
Configuration to Modify:
```bash
#SBATCH --job-name=qat # CHANGE: Your job name
#SBATCH -A jgmakin # CHANGE: Your allocation account
#SBATCH -p a100-40gb # Change partition as needed
#SBATCH --time=00:02:00 # CHANGE: Appropriate time limit
#SBATCH --output=logs/training_run.out
#SBATCH --error=logs/training_run.err
# Change email in the script body for real failures
echo "..." | mail -s "..." rjaiswa@purdue.edu # CHANGE: Your email
# Change conda environment
conda activate myproj # CHANGE: Your environment name
```

Usage:
```bash
sbatch train_experiment.sh
```

Exit Codes:
- `0`: success, or a timeout treated as success (checkpoint saved).
- `137` / `143`: SIGKILL/SIGTERM (usually a timeout), normalized to 0.
- Other: real failure; an email is sent.
Add these functions to your training script:
```python
import os
import random
import numpy as np
import torch
import hashlib
import json
def get_random_states():
"""Capture all random number generator states."""
states = {
'python': random.getstate(),
'numpy': np.random.get_state(),
'torch': torch.get_rng_state().cpu(), # Ensure CPU tensor
}
if torch.cuda.is_available():
states['torch_cuda'] = [state.cpu() for state in torch.cuda.get_rng_state_all()]
else:
states['torch_cuda'] = None
return states
def set_random_states(random_states):
"""Restore all random number generator states."""
if random_states is None:
return
try:
random.setstate(random_states['python'])
np.random.set_state(random_states['numpy'])
torch.set_rng_state(random_states['torch'].cpu())
if random_states.get('torch_cuda') is not None and torch.cuda.is_available():
for i, state in enumerate(random_states['torch_cuda']):
torch.cuda.set_rng_state(state.cpu(), i)
except Exception as e:
print(f"Warning: Could not restore random states: {e}")
print("Continuing with current random state...")
def get_config_hash(config):
"""Create a unique hash from the config to identify the checkpoint."""
config_str = json.dumps(config, sort_keys=True)
return hashlib.md5(config_str.encode()).hexdigest() # Full 128-bit MD5 (32 hex chars)
def get_checkpoint_path(config, checkpoint_dir):
"""Generate the checkpoint filename based on the config hash."""
config_hash = get_config_hash(config)
return os.path.join(checkpoint_dir, f"checkpoint_{config_hash}.pth")
def save_checkpoint(state, checkpoint_path):
"""Save a checkpoint with all training state."""
print(f"Saving checkpoint to {checkpoint_path}")
torch.save(state, checkpoint_path)
def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device):
"""Load a checkpoint and return epoch, config, best_acc, and random states."""
print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)  # checkpoint holds non-tensor objects (RNG states, config)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler is not None and 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
random_states = checkpoint.get('random_states', None)
return checkpoint['epoch'], checkpoint['config'], checkpoint.get('best_acc', 0.0), random_states
def configs_match(config1, config2):
"""Check if two configs are identical."""
    return get_config_hash(config1) == get_config_hash(config2)
```

Store all hyperparameters in a dictionary:
```python
config = {
'epochs': 200,
'batch_size': 128,
'lr': 0.1,
'optimizer': 'sgd',
'momentum': 0.9,
'weight_decay': 5e-4,
'scheduler': 'cosine',
'num_workers': 4,
'checkpoint_dir': './checkpoints'
}
```

Important: Keep this config consistent. Changing any value creates a new checkpoint file.
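For example, with the helpers from the Quick Start, changing a single value (here the learning rate, chosen arbitrarily for illustration) yields a different hash and therefore a different checkpoint file:

```python
config_a = dict(config)            # the config above
config_b = dict(config, lr=0.01)   # one hyperparameter changed

print(get_checkpoint_path(config_a, config['checkpoint_dir']))
print(get_checkpoint_path(config_b, config['checkpoint_dir']))
# Two different files, e.g. checkpoint_<hash_a>.pth vs. checkpoint_<hash_b>.pth,
# so a run with config_b will never silently resume from config_a's checkpoint.

assert not configs_match(config_a, config_b)
```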
```python
import torch.optim as optim  # needed for the optimizer and scheduler below

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create model
model = YourModel().to(device)
# Create optimizer
if config['optimizer'] == 'sgd':
optimizer = optim.SGD(
model.parameters(), lr=config['lr'], momentum=config['momentum'], weight_decay=config['weight_decay']
)
elif config['optimizer'] == 'adam':
optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
# Create scheduler (optional)
scheduler = None
if config.get('scheduler') == 'cosine':
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
elif config.get('scheduler') == 'step':
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
```

```python
from pathlib import Path
# Create checkpoint directory
checkpoint_dir = config['checkpoint_dir']
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
# Before training loop: check for existing checkpoint
checkpoint_path = get_checkpoint_path(config, checkpoint_dir)
start_epoch = 0
best_acc = 0.0
if os.path.exists(checkpoint_path):
try:
loaded_epoch, loaded_config, loaded_best_acc, random_states = load_checkpoint(
checkpoint_path, model, optimizer, scheduler, device
)
if configs_match(config, loaded_config):
print("✓ Checkpoint found with matching hyperparameters. Resuming training.")
start_epoch = loaded_epoch + 1
best_acc = loaded_best_acc
print(f" Resuming from epoch {start_epoch}, best acc: {best_acc:.2f}%")
# Restore random states for reproducible data loading
if random_states is not None:
set_random_states(random_states)
print(" Random states restored (DataLoader will use the same shuffling)")
else:
print("⚠ Checkpoint found but hyperparameters differ. Starting fresh.")
except Exception as e:
print(f"⚠ Error loading checkpoint: {e}. Starting fresh.")
start_epoch = 0
best_acc = 0.0
# Training loop
for epoch in range(start_epoch, config['epochs']):
# Train for one epoch
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
# Validate
val_loss, val_acc = validate(model, test_loader, criterion, device)
# Update best accuracy
if val_acc > best_acc:
best_acc = val_acc
# Update scheduler
if scheduler is not None:
scheduler.step()
# Save checkpoint at end of epoch
checkpoint_state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
'best_acc': best_acc,
'config': config,
'train_loss': train_loss,
'train_acc': train_acc,
'val_loss': val_loss,
'val_acc': val_acc,
'random_states': get_random_states(), # Save random states
}
    save_checkpoint(checkpoint_state, checkpoint_path)
```

- `get_random_states()`: captures Python, NumPy, and PyTorch CPU/CUDA RNG states.
- `set_random_states(states)`: restores all RNG states (gracefully handles errors).
- `get_config_hash(config)`: creates a full 128-bit MD5 hash (32 hex chars) from the config dict.
- `get_checkpoint_path(config, dir)`: generates a unique checkpoint filename based on the config hash.
- `save_checkpoint(state, path)`: saves a checkpoint to disk using `torch.save()`.
- `load_checkpoint(path, model, optimizer, scheduler, device)`: loads a checkpoint and restores model/optimizer/scheduler states (in place).
- `configs_match(config1, config2)`: validates config compatibility by comparing hashes.
Key point: the config hash prevents accidentally loading checkpoints trained with different hyperparameters.
Solution: mismatched hyperparameters should be detected automatically via `configs_match()`. If not, check:
```python
if configs_match(config, loaded_config):
    ...  # Safe to resume
else:
    ...  # Should start fresh
```

Possible causes:
- Checkpoint loading exception
- Config mismatch
- Corrupted checkpoint file
Debug:
```python
try:
loaded_epoch, _, _, _ = load_checkpoint(...)
print(f"Loaded epoch: {loaded_epoch}")
except Exception as e:
print(f"Error: {e}")
import traceback
    traceback.print_exc()
```

Solution: ensure you are not creating the model twice. Checkpoint loading modifies the existing model in place.
```python
# WRONG
model = Model()
checkpoint = load_checkpoint(...)
model = Model() # Don't recreate!
# CORRECT
model = Model()
checkpoint = load_checkpoint(checkpoint_path, model, ...)  # Modifies model in place
```

You can add any additional information to checkpoints. Include extra keys when saving and read them when loading:
```python
checkpoint_state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_acc': best_acc,
'config': config,
'random_states': get_random_states(),
# Custom keys
'patience': patience,
'training_time': total_time,
'gpu_memory': torch.cuda.max_memory_allocated(),
'experiment_name': 'my_experiment',
'git_commit': get_git_commit_hash(),
}
```

For multi-GPU training, only save from rank 0:
```python
if dist.get_rank() == 0:
    save_checkpoint(state, checkpoint_path)
```
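One caveat, assuming a typical DistributedDataParallel setup (not something the code above enforces): the DDP wrapper prefixes `state_dict` keys with `module.`, so saving the unwrapped module keeps the checkpoint loadable by single-GPU runs too. A hedged sketch, where `ddp_model` is assumed to be your DDP-wrapped model:

```python
if dist.get_rank() == 0:
    state = {
        'epoch': epoch,
        # .module strips the DDP wrapper, so keys carry no "module." prefix
        'model_state_dict': ddp_model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'random_states': get_random_states(),
    }
    save_checkpoint(state, checkpoint_path)
dist.barrier()  # other ranks wait until rank 0 has finished writing
```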
Ensure all ranks load the checkpoint:

```python
# Every rank loads the checkpoint (model/optimizer/scheduler restored in place)
epoch, loaded_config, best_acc, random_states = load_checkpoint(
    checkpoint_path, model, optimizer, scheduler, device
)

# Broadcast rank 0's random states so all ranks shuffle identically
if dist.get_rank() != 0:
    random_states = None
obj_list = [random_states]
dist.broadcast_object_list(obj_list, src=0)  # broadcasts in place; returns None
set_random_states(obj_list[0])
```

What You Get
- ✅ Automatic checkpoint save/restore
- ✅ Reproducible data shuffling via random state preservation
- ✅ Config validation to prevent loading incompatible checkpoints
- ✅ Simple, clean API
- ✅ Graceful error handling
What to Remember
- Always save random states in checkpoints: `'random_states': get_random_states()`
- Always restore random states after loading: `set_random_states(random_states)`
- Keep your config consistent: changing it creates a new checkpoint
- Save checkpoints at epoch boundaries for simplicity
- Test with forced timeouts to verify that resumption works
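One way to force such a test without waiting for a real SLURM time limit is to abort the process from inside the training script and then rerun it. This is only a sketch under that assumption; `TEST_TIMEOUT_SECONDS` and the handler are illustrative and not part of the framework:

```python
import os
import signal

TEST_TIMEOUT_SECONDS = 90  # illustrative: abort roughly 90 s into training

def _simulate_timeout(signum, frame):
    # Mimic hitting a cluster time limit: exit abruptly, mid-epoch, saving nothing.
    print("Simulated timeout: exiting mid-epoch")
    os._exit(143)  # the exit code a SIGTERM-killed job is usually reported with

signal.signal(signal.SIGALRM, _simulate_timeout)
signal.alarm(TEST_TIMEOUT_SECONDS)

# ... run the normal training loop here, then rerun the script and check that
# the resumed epoch reproduces the same batch losses as the interrupted one.
```

On the rerun you should see the same "Epoch N [0/...]" loss as in the interrupted run, as in the example at the top of this guide.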
Reference Implementation
See train.py for a complete working example with ResNet‑20 on CIFAR‑10.