nextquanta/purdue_gilbreth_resume


Training Resumption Guide

A complete guide to implementing robust, checkpoint-based training resumption with reproducible random states for PyTorch models.


Overview

This training framework provides:

  • ✅ Automatic checkpoint saving after each epoch
  • ✅ Reproducible random states for consistent data shuffling
  • ✅ Config-based checkpoint hashing to prevent loading incompatible checkpoints
  • ✅ Graceful interruption handling for SLURM/cluster time limits
  • ✅ A simple API that’s easy to integrate into existing code

Training Interruption

If training is interrupted mid-epoch and later resumed, the interrupted epoch restarts from the last end-of-epoch checkpoint with exactly the same data shuffle, producing identical batch losses (up to any nondeterminism in GPU kernels). This makes debugging easier and training more predictable.


Why This Matters

Without random state saving:

First run:  Epoch 2 [0/1000] Loss: 1.993
[Job interrupted]
Second run: Epoch 2 [0/1000] Loss: 1.824  ❌ DIFFERENT!

With random state saving:

First run:  Epoch 2 [0/1000] Loss: 1.993
[Job interrupted]
Second run: Epoch 2 [0/1000] Loss: 1.993  ✅ SAME!

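Why? With shuffle=True and no explicit generator, the DataLoader draws a new random permutation from PyTorch's global RNG each epoch. When you restore the saved random states, PyTorch draws the exact same permutation, so batch 0 contains the same samples; with the model weights also restored, that gives the same loss.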
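A minimal way to see this in isolation, using a toy eight-sample TensorDataset that is not part of the repository: record the global torch RNG state, iterate one epoch, restore the state, and iterate again. Both passes should visit the samples in the same order. The sketch assumes num_workers=0 to keep things simple.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of the integers 0..7, so each sample identifies itself
dataset = TensorDataset(torch.arange(8))
loader = DataLoader(dataset, batch_size=4, shuffle=True)  # num_workers=0 for simplicity

def epoch_order(loader):
    """Return the order in which one epoch visits the samples."""
    return [x.item() for (batch,) in loader for x in batch]

torch.manual_seed(0)
saved_state = torch.get_rng_state()  # what get_random_states() stores under 'torch'

first_run = epoch_order(loader)      # "first job"

torch.set_rng_state(saved_state)     # "second job" restores the checkpointed state
resumed_run = epoch_order(loader)

print(first_run == resumed_run)  # True
```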


SLURM Scripts

Two scripts are provided for running training on SLURM clusters (e.g., Purdue Gilbreth). The master script resume_training.sh submits NUM_RUNS chained jobs, each invoking train_experiment.sh to continue training from the latest checkpoint.

resume_training.sh — Chain Multiple Jobs

Submits a chain of dependent SLURM jobs that automatically resume training.

Key Features:

  • Sequential execution: starts the next job only after the previous one finishes or times out (--dependency=afterany).
  • Failure tolerance: a failed or timed-out job does not break the chain; the next job still starts.
  • Email notification: sends an email when the final job completes.
  • Logging: appends stdout to logs/training_run.out and stderr to logs/training_run.err.

Configuration to Modify:

NUM_RUNS=10                         # CHANGE: Number of jobs to chain

# Change email for the final job
EMAIL_FLAGS="--mail-type=END --mail-user=rjaiswa@purdue.edu"  # CHANGE

# Customize job name prefix
--job-name="cifar_run_${i}"        # CHANGE: Prefix as needed

Usage:

./resume_training.sh

How It Works:

  1. Submits job 1.
  2. Submits job 2 with --dependency=afterany:JOB1_ID so it runs after job 1 regardless of exit status.
  3. Repeats for NUM_RUNS jobs.
  4. The final job sends an email notification.
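The steps above amount to a loop that feeds each job ID into the next job's --dependency flag. resume_training.sh itself is a shell script whose body is not reproduced here; the Python sketch below is a hypothetical illustration of the same chaining logic, not the repository's script. It assumes sbatch's --parsable output format and includes a dry-run mode with placeholder IDs so it can be tried off-cluster.

```python
import subprocess

def submit_chain(num_runs, script="train_experiment.sh", final_flags=(), dry_run=True):
    """Submit num_runs jobs, each depending (afterany) on the previous one.

    Returns the sbatch commands. With dry_run=True nothing is submitted and
    placeholder IDs such as "<job1>" stand in for real SLURM job IDs.
    """
    commands = []
    prev_id = None
    for i in range(1, num_runs + 1):
        cmd = ["sbatch", "--parsable", f"--job-name=cifar_run_{i}"]
        if prev_id is not None:
            # afterany: start once the previous job ends, whatever its exit status
            cmd.append(f"--dependency=afterany:{prev_id}")
        if i == num_runs:
            cmd.extend(final_flags)  # e.g. mail flags only on the last job
        cmd.append(script)
        commands.append(cmd)
        if dry_run:
            prev_id = f"<job{i}>"
        else:
            # --parsable prints "jobid" or "jobid;cluster"
            out = subprocess.run(cmd, check=True, capture_output=True, text=True).stdout
            prev_id = out.strip().split(";")[0]
    return commands

# Dry run; the address is a placeholder
for cmd in submit_chain(3, final_flags=["--mail-type=END", "--mail-user=you@example.edu"]):
    print(" ".join(cmd))
```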

Monitoring:

# Watch job queue
watch -n 2 squeue -u $USER

# View live training progress
tail -f logs/training_run.out

Example Output:

Submitted run 1/10 as job 12345
Submitted run 2/10 as job 12346
Submitted run 3/10 as job 12347
...
All 10 jobs submitted. Last job ID: 12354
You'll receive an email when job 12354 completes.

train_experiment.sh — Single Training Job

Runs a single training session with automatic checkpoint resumption.

Key Features:

  • Timeout handling: distinguishes timeouts from real failures (not foolproof).
  • Email on real failures only, not on timeouts (not foolproof).
  • Logging: appends to logs/training_run.out and logs/training_run.err.

Configuration to Modify:

#SBATCH --job-name=qat              # CHANGE: Your job name
#SBATCH -A jgmakin                  # CHANGE: Your allocation account
#SBATCH -p a100-40gb                # Change partition as needed
#SBATCH --time=00:02:00             # CHANGE: Appropriate time limit
#SBATCH --output=logs/training_run.out
#SBATCH --error=logs/training_run.err

# Change email in the script body for real failures
echo "..." | mail -s "..." rjaiswa@purdue.edu  # CHANGE: Your email

# Change conda environment
conda activate myproj               # CHANGE: Your environment name

Usage:

sbatch train_experiment.sh

Exit Codes:

  • 0: Success, or a timeout treated as success (the last end-of-epoch checkpoint is kept).
  • 137/143: SIGKILL/SIGTERM (128 + signal number, usually a timeout) → normalized to 0.
  • Any other code: Real failure → email sent.
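train_experiment.sh is a shell script and its source is not shown here, so the Python sketch below only restates the exit-code policy from the list above as a hypothetical helper. It also covers Python's subprocess module, which reports a child killed by signal N as -N rather than the shell's 128 + N.

```python
SIGKILL, SIGTERM = 9, 15  # POSIX signal numbers

# Shells report "killed by signal N" as 128 + N (137, 143);
# Python's subprocess.returncode reports it as -N.
KILLED_BY_TIMEOUT_SIGNAL = {128 + SIGKILL, 128 + SIGTERM, -SIGKILL, -SIGTERM}

def classify_exit(code):
    """Hypothetical helper: map an exit status to (normalized_code, is_real_failure)."""
    if code == 0 or code in KILLED_BY_TIMEOUT_SIGNAL:
        # Success, or most likely a SLURM time-limit kill: the last end-of-epoch
        # checkpoint is on disk and the next chained job resumes from it
        return 0, False
    return code, True  # real failure: where the script would send an email

for status in (0, 137, 143, 1):
    print(status, classify_exit(status))
```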

Quick Start

1. Copy the Core Functions

Add these functions to your training script:

import os
import random
import numpy as np
import torch
import hashlib
import json

def get_random_states():
    """Capture all random number generator states."""
    states = {
        'python': random.getstate(),
        'numpy': np.random.get_state(),
        'torch': torch.get_rng_state().cpu(),  # Ensure CPU tensor
    }
    if torch.cuda.is_available():
        states['torch_cuda'] = [state.cpu() for state in torch.cuda.get_rng_state_all()]
    else:
        states['torch_cuda'] = None
    return states

def set_random_states(random_states):
    """Restore all random number generator states."""
    if random_states is None:
        return
    try:
        random.setstate(random_states['python'])
        np.random.set_state(random_states['numpy'])
        torch.set_rng_state(random_states['torch'].cpu())
        if random_states.get('torch_cuda') is not None and torch.cuda.is_available():
            for i, state in enumerate(random_states['torch_cuda']):
                torch.cuda.set_rng_state(state.cpu(), i)
    except Exception as e:
        print(f"Warning: Could not restore random states: {e}")
        print("Continuing with current random state...")

def get_config_hash(config):
    """Create a unique hash from the config to identify the checkpoint."""
    config_str = json.dumps(config, sort_keys=True)
    return hashlib.md5(config_str.encode()).hexdigest()  # Full 128-bit MD5 (32 hex chars)

def get_checkpoint_path(config, checkpoint_dir):
    """Generate the checkpoint filename based on the config hash."""
    config_hash = get_config_hash(config)
    return os.path.join(checkpoint_dir, f"checkpoint_{config_hash}.pth")

def save_checkpoint(state, checkpoint_path):
    """Save a checkpoint with all training state."""
    print(f"Saving checkpoint to {checkpoint_path}")
    torch.save(state, checkpoint_path)

def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device):
    """Load a checkpoint and return epoch, config, best_acc, and random states."""
    print(f"Loading checkpoint from {checkpoint_path}")
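    # weights_only=False: the checkpoint stores NumPy/Python RNG states, which the
    # weights_only=True default of PyTorch 2.6+ refuses to unpickle. Only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)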

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    random_states = checkpoint.get('random_states', None)
    return checkpoint['epoch'], checkpoint['config'], checkpoint.get('best_acc', 0.0), random_states

def configs_match(config1, config2):
    """Check if two configs are identical."""
    return get_config_hash(config1) == get_config_hash(config2)
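Before relying on these functions, it is worth checking that the three CPU generator states captured by get_random_states() survive a torch.save()/torch.load() round trip and replay the same draws. The sketch below uses an in-memory buffer instead of a checkpoint file (an assumption for brevity) and passes weights_only=False explicitly, since the NumPy state contains an ndarray that recent PyTorch releases (2.6 and later) will not unpickle under their weights_only=True default.

```python
import io
import random

import numpy as np
import torch

# Capture the same CPU generator states that get_random_states() captures
states = {
    'python': random.getstate(),
    'numpy': np.random.get_state(),
    'torch': torch.get_rng_state(),
}
expected = (random.random(), np.random.rand(), torch.rand(1).item())

# Round-trip through torch.save/torch.load, as a checkpoint would
buffer = io.BytesIO()
torch.save({'random_states': states}, buffer)
buffer.seek(0)
loaded = torch.load(buffer, weights_only=False)['random_states']

# Restore and draw again: the same numbers should come out
random.setstate(loaded['python'])
np.random.set_state(loaded['numpy'])
torch.set_rng_state(loaded['torch'])
replayed = (random.random(), np.random.rand(), torch.rand(1).item())

print(replayed == expected)  # True
```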

2. Prepare Your Config

Store all hyperparameters in a dictionary:

config = {
    'epochs': 200,
    'batch_size': 128,
    'lr': 0.1,
    'optimizer': 'sgd',
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'scheduler': 'cosine',
    'num_workers': 4,
    'checkpoint_dir': './checkpoints'
}

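Important: keep this config identical between runs. The checkpoint filename is derived from a hash of the whole dict, so changing any value (even num_workers or checkpoint_dir) points to a new checkpoint file and training starts from scratch.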
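Because the hash is computed over the JSON serialization, "the same config" means "serializes to the same JSON string". The standalone check below reuses the recipe from get_config_hash() with made-up config values. Note the last case: an int and a float with the same numeric value produce different hashes, so writing 128.0 instead of 128 silently starts a new run.

```python
import hashlib
import json

def get_config_hash(config):
    # Same recipe as above: canonical JSON (sorted keys) -> MD5 hex digest
    return hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()

base = {'lr': 0.1, 'batch_size': 128, 'epochs': 200}
reordered = {'epochs': 200, 'batch_size': 128, 'lr': 0.1}
changed = {'lr': 0.1, 'batch_size': 128, 'epochs': 300}
int_vs_float = {'lr': 0.1, 'batch_size': 128.0, 'epochs': 200}

print(get_config_hash(base) == get_config_hash(reordered))     # True: key order is ignored
print(get_config_hash(base) == get_config_hash(changed))       # False: any value change matters
print(get_config_hash(base) == get_config_hash(int_vs_float))  # False: 128 and 128.0 serialize differently
```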

3. Initialize Model, Optimizer, Scheduler

import torch
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create model (YourModel is a placeholder for your own nn.Module)
model = YourModel().to(device)

# Create optimizer
if config['optimizer'] == 'sgd':
    optimizer = optim.SGD(
        model.parameters(), lr=config['lr'], momentum=config['momentum'], weight_decay=config['weight_decay']
    )
elif config['optimizer'] == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
else:
    raise ValueError(f"Unknown optimizer: {config['optimizer']}")

# Create scheduler (optional)
scheduler = None
if config.get('scheduler') == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
elif config.get('scheduler') == 'step':
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

4. Modify Your Training Loop

from pathlib import Path

# Create checkpoint directory
checkpoint_dir = config['checkpoint_dir']
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

# Before training loop: check for existing checkpoint
checkpoint_path = get_checkpoint_path(config, checkpoint_dir)

start_epoch = 0
best_acc = 0.0

if os.path.exists(checkpoint_path):
    try:
        loaded_epoch, loaded_config, loaded_best_acc, random_states = load_checkpoint(
            checkpoint_path, model, optimizer, scheduler, device
        )

        if configs_match(config, loaded_config):
            print("✓ Checkpoint found with matching hyperparameters. Resuming training.")
            start_epoch = loaded_epoch + 1
            best_acc = loaded_best_acc
            print(f"  Resuming from epoch {start_epoch}, best acc: {best_acc:.2f}%")

            # Restore random states for reproducible data loading
            if random_states is not None:
                set_random_states(random_states)
                print("  Random states restored (DataLoader will use the same shuffling)")
        else:
            print("⚠ Checkpoint found but hyperparameters differ. Starting fresh.")
    except Exception as e:
        print(f"⚠ Error loading checkpoint: {e}. Starting fresh.")
        start_epoch = 0
        best_acc = 0.0

# Training loop
for epoch in range(start_epoch, config['epochs']):
    # Train for one epoch
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch)

    # Validate
    val_loss, val_acc = validate(model, test_loader, criterion, device)

    # Update best accuracy
    if val_acc > best_acc:
        best_acc = val_acc

    # Update scheduler
    if scheduler is not None:
        scheduler.step()

    # Save checkpoint at end of epoch
    checkpoint_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'best_acc': best_acc,
        'config': config,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'random_states': get_random_states(),  # Save random states
    }
    save_checkpoint(checkpoint_state, checkpoint_path)
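Because the loop calls scheduler.step() before saving, the checkpoint for epoch e already holds the learning rate for epoch e + 1, which is why resuming at loaded_epoch + 1 is correct. The sketch below, using a toy nn.Linear model with CosineAnnealingLR chosen only for illustration, checks that restoring the optimizer and scheduler state dicts into freshly built objects continues the same learning-rate schedule.

```python
import torch
from torch import nn, optim

def build():
    """Fresh optimizer/scheduler, as a newly started job would create them."""
    model = nn.Linear(2, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    return optimizer, scheduler

def run_epochs(optimizer, scheduler, n):
    lrs = []
    for _ in range(n):
        optimizer.step()  # stands in for a training epoch (no gradients, so no update)
        scheduler.step()  # end-of-epoch scheduler update, as in the loop above
        lrs.append(scheduler.get_last_lr()[0])
    return lrs

# Uninterrupted: 6 epochs
opt, sched = build()
uninterrupted = run_epochs(opt, sched, 6)

# Interrupted after 3 epochs, then resumed from the saved state dicts
opt, sched = build()
first_half = run_epochs(opt, sched, 3)
saved = {'optimizer_state_dict': opt.state_dict(), 'scheduler_state_dict': sched.state_dict()}

opt, sched = build()
opt.load_state_dict(saved['optimizer_state_dict'])
sched.load_state_dict(saved['scheduler_state_dict'])
second_half = run_epochs(opt, sched, 3)

print(first_half + second_half == uninterrupted)  # True
```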

Core Components

Random State Management

  • get_random_states() — captures Python, NumPy, and PyTorch CPU/CUDA RNG states.
  • set_random_states(states) — restores all RNG states (gracefully handles errors).

Checkpoint Management

  • get_config_hash(config) — creates a full 128-bit MD5 hash (32 hex chars) from the config dict.
  • get_checkpoint_path(config, dir) — generates a unique checkpoint filename based on the config hash.
  • save_checkpoint(state, path) — saves a checkpoint to disk using torch.save().
  • load_checkpoint(path, model, optimizer, scheduler, device) — loads a checkpoint and restores model/optimizer/scheduler states (in place).
  • configs_match(config1, config2) — validates config compatibility by comparing hashes.

Key point: the config hash prevents accidentally loading checkpoints trained with different hyperparameters.
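save_checkpoint() writes directly to the final path, so a job killed mid-write (for example at the SLURM time limit) can leave a truncated file behind. One common mitigation, not part of the original code and sketched here as a hypothetical drop-in replacement, is to write to a temporary file in the same directory and then rename it over the old checkpoint with os.replace(), which is atomic on POSIX filesystems when both paths are on the same filesystem.

```python
import os
import tempfile

import torch

def save_checkpoint_atomic(state, checkpoint_path):
    """Hypothetical variant of save_checkpoint() that never leaves a partial file."""
    directory = os.path.dirname(os.path.abspath(checkpoint_path))
    # Temp file in the same directory, so the final rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(state, f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes reach the disk before renaming
        os.replace(tmp_path, checkpoint_path)  # old checkpoint stays intact until this point
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

If the job dies before os.replace() runs, the previous checkpoint is still the one on disk, and only a stray .tmp file may remain.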


Troubleshooting

Issue: Checkpoint from a different config loads

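Solution: the code should detect this automatically via configs_match(). If it does not, make sure the comparison actually decides whether to resume:

if configs_match(config, loaded_config):
    start_epoch = loaded_epoch + 1  # Safe to resume
else:
    start_epoch = 0  # Should start fresh

Note that load_checkpoint() copies the saved weights and optimizer state into your objects before configs_match() runs, so on a mismatch, re-create the model, optimizer, and scheduler before starting fresh.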

Issue: Training starts from epoch 0 despite having a checkpoint

Possible causes:

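  1. Checkpoint loading exception (for example, PyTorch 2.6+ defaults torch.load() to weights_only=True, which rejects the pickled NumPy RNG state unless weights_only=False is passed for trusted files)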
  2. Config mismatch
  3. Corrupted checkpoint file

Debug:

import traceback

try:
    loaded_epoch, _, _, _ = load_checkpoint(checkpoint_path, model, optimizer, scheduler, device)
    print(f"Loaded epoch: {loaded_epoch}")
except Exception as e:
    print(f"Error: {e}")
    traceback.print_exc()
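To tell the causes above apart without touching the model, it can help to inspect the file directly. The helper below is hypothetical (not part of the guide's API); it loads the raw dict on the CPU and reports whether the file is missing, unreadable, from a different config, or resumable.

```python
import hashlib
import json
import os

import torch

def config_hash(config):
    return hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()

def diagnose_checkpoint(checkpoint_path, config):
    """Hypothetical helper: explain whether a checkpoint would be resumed."""
    if not os.path.exists(checkpoint_path):
        return "missing: no file at this path (did the config, and hence the hash, change?)"
    try:
        # map_location="cpu" avoids allocating GPU memory just to inspect the file
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    except Exception as e:  # truncated/corrupted file or unpickling error
        return f"unreadable: {type(e).__name__}: {e}"
    if config_hash(ckpt.get("config", {})) != config_hash(config):
        return "config mismatch"
    return f"resumable from epoch {ckpt['epoch'] + 1}"
```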

Issue: Out of memory after loading a checkpoint

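Solution: ensure you are not creating the model twice. load_checkpoint() modifies the existing model in place. Re-creating the model afterwards throws away the loaded weights, and because the optimizer still references the first model's parameters, both copies stay in GPU memory.

# WRONG
model = Model().to(device)
optimizer = optim.SGD(model.parameters(), lr=config['lr'])
load_checkpoint(checkpoint_path, model, optimizer, scheduler, device)
model = Model().to(device)  # Don't recreate! Loaded weights are lost

# CORRECT
model = Model().to(device)
optimizer = optim.SGD(model.parameters(), lr=config['lr'])
load_checkpoint(checkpoint_path, model, optimizer, scheduler, device)  # Modifies model in place

If memory is still tight, loading with map_location='cpu' inside load_checkpoint() keeps the checkpoint's tensors off the GPU; load_state_dict() then copies them to the devices of the existing model and optimizer parameters.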

Advanced Topics

Custom Checkpoint Keys

You can add any additional information to checkpoints. Include extra keys when saving and read them when loading:

checkpoint_state = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
    'best_acc': best_acc,
    'config': config,
    'random_states': get_random_states(),

    # Custom keys (patience, total_time, and get_git_commit_hash() are your own)
    'patience': patience,
    'training_time': total_time,
    'gpu_memory': torch.cuda.max_memory_allocated(),
    'experiment_name': 'my_experiment',
    'git_commit': get_git_commit_hash(),
}
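get_git_commit_hash() is not defined in this guide, and load_checkpoint() returns only four values, so reading custom keys back means working with the raw dict from torch.load(). The sketch below shows one hypothetical way to implement both pieces: the commit hash comes from `git rev-parse HEAD`, and missing keys fall back to defaults so older checkpoints still load.

```python
import subprocess

def get_git_commit_hash(default="unknown"):
    """Hypothetical helper: current commit of the working directory's repo, or `default`."""
    try:
        result = subprocess.run(["git", "rev-parse", "HEAD"],
                                capture_output=True, text=True, check=True)
        return result.stdout.strip()
    except (OSError, subprocess.CalledProcessError):  # git missing, or not inside a repo
        return default

def read_custom_keys(checkpoint, defaults):
    """Read optional keys from a loaded checkpoint dict, falling back to defaults
    for checkpoints written before a key was added."""
    return {key: checkpoint.get(key, value) for key, value in defaults.items()}

# Example with an already-loaded checkpoint dict
checkpoint = {'epoch': 5, 'patience': 3}
extras = read_custom_keys(checkpoint, {'patience': 0, 'git_commit': 'unknown'})
# extras == {'patience': 3, 'git_commit': 'unknown'}
```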

Distributed Training

For multi-GPU training, only save from rank 0:

import torch.distributed as dist

if dist.get_rank() == 0:
    save_checkpoint(state, checkpoint_path)

Ensure all ranks load the checkpoint:

_, _, _, random_states = load_checkpoint(checkpoint_path, model, optimizer, scheduler, device)
# Broadcast rank 0's random states to all ranks for consistency.
# broadcast_object_list() fills the list in place and returns None.
objects = [random_states if dist.get_rank() == 0 else None]
dist.broadcast_object_list(objects, src=0)
set_random_states(objects[0])
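With DistributedSampler, the usual sampler under DDP, the per-epoch shuffle does not come from the global RNG at all: it is generated from the sampler's seed plus the epoch passed to set_epoch(). In that setting, restoring RNG states does not control the shuffle; what matters is calling set_epoch(epoch) at the start of every epoch, with epoch continuing from start_epoch. The sketch assumes a toy 16-sample dataset and two ranks, passing num_replicas and rank explicitly so no process group is needed.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))

def shard_order(epoch, rank, world_size=2, seed=0):
    """Indices a given rank would see in a given epoch."""
    # num_replicas/rank given explicitly, so this runs without init_process_group()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=seed)
    sampler.set_epoch(epoch)  # call every epoch, using the resumed epoch index
    return list(sampler)

before = shard_order(epoch=3, rank=0)
torch.manual_seed(1234)  # perturb the global RNG, as a restarted job might
after = shard_order(epoch=3, rank=0)
print(before == after)  # True: the shuffle depends on seed + epoch only

rank0, rank1 = shard_order(3, rank=0), shard_order(3, rank=1)
print(sorted(rank0 + rank1) == list(range(16)))  # True: the ranks split the data
```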

Summary

What You Get

  • ✅ Automatic checkpoint save/restore
  • ✅ Reproducible data shuffling via random state preservation
  • ✅ Config validation to prevent loading incompatible checkpoints
  • ✅ Simple, clean API
  • ✅ Graceful error handling

What to Remember

  1. Always save random states in checkpoints: 'random_states': get_random_states()
  2. Always restore random states after loading: set_random_states(random_states)
  3. Keep your config consistent — changing it creates a new checkpoint
  4. Save checkpoints at epoch boundaries for simplicity
  5. Test with forced timeouts to verify that resumption works
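A quick way to test resumption without waiting for a real SLURM timeout is to simulate the interruption in a single process. The sketch below is a toy setup (random data, an nn.Linear model, SGD with momentum, an in-memory buffer instead of a file), not the repository's train.py. It compares an uninterrupted run against one that is "killed" after two epochs and rebuilt from scratch. On the CPU the final weights should match bit for bit; if you comment out the torch.set_rng_state() line, the runs will almost certainly diverge.

```python
import io

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

EPOCHS, INTERRUPT_AFTER = 4, 2

def fresh_job():
    """Build everything from scratch, as a newly started SLURM job would."""
    torch.manual_seed(0)  # same data and same initial weights in every "job"
    data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    loader = DataLoader(data, batch_size=8, shuffle=True)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return loader, model, optimizer

def train_epoch(loader, model, optimizer):
    for x, y in loader:
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()

# Reference run: never interrupted
loader, model, optimizer = fresh_job()
for epoch in range(EPOCHS):
    train_epoch(loader, model, optimizer)
reference = model.weight.detach().clone()

# First job: stops after INTERRUPT_AFTER epochs and saves a checkpoint
loader, model, optimizer = fresh_job()
for epoch in range(INTERRUPT_AFTER):
    train_epoch(loader, model, optimizer)
buffer = io.BytesIO()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'torch_rng': torch.get_rng_state()}, buffer)

# Second job: rebuild, restore, and finish the remaining epochs
loader, model, optimizer = fresh_job()
buffer.seek(0)
checkpoint = torch.load(buffer, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
torch.set_rng_state(checkpoint['torch_rng'])  # restores the DataLoader shuffle sequence
for epoch in range(INTERRUPT_AFTER, EPOCHS):
    train_epoch(loader, model, optimizer)

print(torch.equal(model.weight, reference))  # True on CPU
```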

Reference Implementation

See train.py for a complete working example with ResNet-20 on CIFAR-10.
