Skip to content

Commit bdedee7

Browse files
CarlG0123Carl Grossclaude
authored
Add MMLU benchmark evaluation to evals (#1183)
* adding MMLU to evals, updating corresponding tests * Fix docstring tests for MMLU functions Skip MMLU docstring examples in doctest runs since they require network access (HuggingFace dataset download) and may require GPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Remove device parameter from mmlu_eval, use model.cfg.device instead Address PR review feedback: use model.cfg.device internally instead of accepting a device parameter, consistent with ioi_eval. This prevents device mismatch errors when users forget to pass the correct device. --------- Co-authored-by: Carl Gross <carl.gross@accenture.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a494811 commit bdedee7

2 files changed

Lines changed: 357 additions & 2 deletions

File tree

tests/acceptance/test_evals.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import pytest
22

3-
from transformer_lens.evals import IOIDataset, ioi_eval
3+
from transformer_lens.evals import (
4+
IOIDataset,
5+
ioi_eval,
6+
make_mmlu_data_loader,
7+
mmlu_eval,
8+
)
49
from transformer_lens.HookedTransformer import HookedTransformer
510

611

@@ -70,3 +75,65 @@ def test_inverted_template(model):
7075
results = ioi_eval(model, dataset=ds)
7176
assert results["Logit Difference"] < -2.0
7277
assert results["Accuracy"] <= 0.01
78+
79+
80+
def test_mmlu_data_loader_single_subject():
    """Loading one MMLU subject returns well-formed example dicts."""
    examples = make_mmlu_data_loader(subjects="abstract_algebra", num_samples=5)
    assert len(examples) == 5
    required_keys = ("question", "choices", "answer", "subject")
    for example in examples:
        assert isinstance(example, dict)
        for key in required_keys:
            assert key in example
        # MMLU is always 4-way multiple choice.
        assert len(example["choices"]) == 4
        assert example["subject"] == "abstract_algebra"
93+
94+
95+
def test_mmlu_data_loader_multiple_subjects():
    """Loading two subjects yields `num_samples` examples from each."""
    requested = ["abstract_algebra", "anatomy"]
    examples = make_mmlu_data_loader(subjects=requested, num_samples=3)
    # 3 samples per subject, two subjects requested.
    assert len(examples) == 6
    assert {example["subject"] for example in examples} == set(requested)
104+
105+
106+
def test_mmlu_data_loader_invalid_subject():
    """
    Test that invalid subject names raise an error.

    Subject validation happens before any dataset download, so this test
    requires no network access.
    """
    with pytest.raises(ValueError, match="Invalid subject"):
        make_mmlu_data_loader(subjects="invalid_subject_name")
112+
113+
114+
def test_mmlu_eval_single_subject(model):
    """
    Smoke-test MMLU evaluation on one subject.

    Uses a small model and few samples so it stays fast in CI.
    """
    results = mmlu_eval(model, subjects="abstract_algebra", num_samples=5)
    for key in ("accuracy", "num_correct", "num_total", "subject_scores"):
        assert key in results
    assert 0 <= results["accuracy"] <= 1
    assert results["num_total"] == 5
    assert results["num_correct"] <= results["num_total"]
    assert "abstract_algebra" in results["subject_scores"]
128+
129+
130+
def test_mmlu_eval_multiple_subjects(model):
    """MMLU evaluation over two subjects reports a score for each."""
    requested = ["abstract_algebra", "anatomy"]
    results = mmlu_eval(model, subjects=requested, num_samples=3)
    # 3 samples per subject, two subjects.
    assert results["num_total"] == 6
    scores = results["subject_scores"]
    assert len(scores) == 2
    for subject in requested:
        assert subject in scores
    for accuracy in scores.values():
        assert 0 <= accuracy <= 1

transformer_lens/evals.py

Lines changed: 289 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import random
9-
from typing import Dict, List, Optional
9+
from typing import Dict, List, Optional, Union
1010

1111
import einops
1212
import torch
@@ -85,6 +85,160 @@ def make_code_data_loader(tokenizer, batch_size=8):
8585
return data_loader
8686

8787

88+
# All 57 subjects available in the MMLU benchmark.
# Each name matches a config name of the HuggingFace "cais/mmlu" dataset
# exactly, so entries are passed straight through to `load_dataset`.
MMLU_SUBJECTS: List[str] = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# The four multiple-choice option labels, ordered so that index i maps to
# the letter for choice i (0 -> "A", ..., 3 -> "D").
MMLU_ANSWER_LETTERS: List[str] = ["A", "B", "C", "D"]
150+
151+
152+
def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """
    Load MMLU (Massive Multitask Language Understanding) dataset.

    MMLU covers 57 subjects across STEM, humanities, social sciences, and
    more; every question is multiple choice with 4 options (A, B, C, D).

    Paper: https://arxiv.org/abs/2009.03300
    Dataset: https://huggingface.co/datasets/cais/mmlu

    Args:
        subjects: Subject(s) to load. None loads all 57 subjects (default);
            a single subject name like "abstract_algebra" loads just that
            subject; a list of names loads each of them.
        split: Dataset split to use - "test", "validation", or "dev".
            Default is "test".
        num_samples: Optional cap on examples per subject; None keeps all.

    Returns:
        List of dicts, one per example, with keys:
            - "question": str
            - "choices": List[str] (4 choices)
            - "answer": int (0-3, correct choice index)
            - "subject": str

    Raises:
        ValueError: If any requested subject is not a valid MMLU subject.

    Examples:

    .. code-block:: python

        >>> from transformer_lens.evals import make_mmlu_data_loader

        >>> # Load specific subject
        >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics") # doctest: +SKIP

        >>> # Load multiple subjects
        >>> mmlu_data = make_mmlu_data_loader( # doctest: +SKIP
        ...     subjects=["abstract_algebra", "astronomy", "college_chemistry"]
        ... )
    """
    # Normalize the `subjects` argument to a list of subject names.
    if subjects is None:
        selected_subjects = list(MMLU_SUBJECTS)
    elif isinstance(subjects, str):
        selected_subjects = [subjects]
    else:
        selected_subjects = list(subjects)

    # Reject unknown names up front, before any network access.
    unknown = set(selected_subjects) - set(MMLU_SUBJECTS)
    if unknown:
        raise ValueError(
            f"Invalid subject(s): {unknown}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    examples = []
    for subject in selected_subjects:
        try:
            subject_dataset = load_dataset("cais/mmlu", subject, split=split)

            # Cap the number of examples taken from this subject if asked.
            if num_samples is None:
                limit = len(subject_dataset)
            else:
                limit = min(num_samples, len(subject_dataset))

            # Convert each raw row into our uniform example dict.
            for index in range(limit):
                row = subject_dataset[index]
                examples.append(
                    {
                        "question": row["question"],
                        "choices": row["choices"],
                        "answer": row["answer"],
                        "subject": subject,
                    }
                )
        except Exception as e:
            # Best-effort: a subject that fails to download is skipped with
            # a warning rather than aborting the whole load.
            print(f"Warning: Could not load subject '{subject}': {e}")
            continue

    print(f"Loaded {len(examples)} MMLU examples from {len(selected_subjects)} subject(s)")
    return examples
240+
241+
88242
DATASET_NAMES = ["wiki", "owt", "pile", "code"]
89243
DATASET_LOADERS = [
90244
make_wiki_data_loader,
@@ -334,3 +488,137 @@ def collate(samples):
334488
"Logit Difference": total_logit_diff / len(dataset),
335489
"Accuracy": total_correct / len(dataset),
336490
}
491+
492+
493+
@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Evaluate a model on the MMLU benchmark.

    MMLU (Massive Multitask Language Understanding) evaluates language models
    on 57 subjects spanning STEM, humanities, social sciences, and more; each
    question is multiple-choice with 4 options.

    Every question is formatted with all four choices (A-D) in the prompt and
    the model's log probability for each answer-letter token at the final
    position is compared. This is a zero-shot evaluation; standard MMLU
    benchmarks typically use 5-shot prompting for higher accuracy.

    Paper: https://arxiv.org/abs/2009.03300

    Args:
        model: HookedTransformer model to evaluate.
        tokenizer: Tokenizer to use. If None, falls back to model.tokenizer.
        subjects: Subject(s) to evaluate on - None for all 57 subjects, a
            single subject name, or a list of names. See
            :const:`MMLU_SUBJECTS` for valid names.
        split: Dataset split - "test", "validation", or "dev". Default "test".
        num_samples: Optional cap on examples per subject; None keeps all.

    Returns:
        Dictionary containing:
            - "accuracy": Overall accuracy (0-1)
            - "num_correct": Number of correct predictions
            - "num_total": Total number of questions
            - "subject_scores": Dict mapping subject names to their accuracy

    Raises:
        ValueError: If no MMLU data could be loaded.

    Examples:

    .. code-block:: python

        >>> from transformer_lens import HookedTransformer
        >>> from transformer_lens.evals import mmlu_eval

        >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP
        >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10) # doctest: +SKIP
        >>> print(f"Accuracy: {results['accuracy']:.2%}") # doctest: +SKIP
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    # Fetch the questions to evaluate on.
    mmlu_data = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)
    if len(mmlu_data) == 0:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    # Precompute the token id for each answer letter once, rather than per
    # question. Prefer the space-prefixed form (" A"), which is how the
    # letter appears immediately after "Answer:"; fall back to the bare
    # letter when the space-prefixed form is not a single token.
    # NOTE(review): the bare-letter fallback assumes its first token matches
    # what the model would produce after "Answer:" - TODO confirm per tokenizer.
    letter_token_ids = []
    for letter in MMLU_ANSWER_LETTERS:
        encoded = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(encoded) != 1:
            encoded = tokenizer.encode(letter, add_special_tokens=False)
        letter_token_ids.append(encoded[0])

    # Running totals, overall and per subject.
    num_correct = 0
    num_total = 0
    per_subject_correct: Dict[str, int] = {}
    per_subject_total: Dict[str, int] = {}

    for example in tqdm.tqdm(mmlu_data, desc="Evaluating MMLU"):
        subject = example["subject"]
        per_subject_correct.setdefault(subject, 0)
        per_subject_total.setdefault(subject, 0)

        # Build the standard MMLU prompt: question, lettered choices, then
        # "Answer:" so the next token should be the answer letter.
        prompt_lines = [f"Question: {example['question']}", "Choices:"]
        for choice_index, choice_text in enumerate(example["choices"]):
            letter = chr(65 + choice_index)  # A, B, C, D
            prompt_lines.append(f"{letter}. {choice_text}")
        prompt = "\n".join(prompt_lines) + "\nAnswer:"

        # Run the model and read log-probs at the final position, which is
        # where the answer letter would be predicted.
        tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)
        logits = model(tokens, return_type="logits")
        final_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)

        # Score each choice by the log-prob of its letter token and pick the
        # highest-scoring one.
        choice_scores = [
            final_log_probs[letter_token_ids[choice_index]].item()
            for choice_index in range(len(example["choices"]))
        ]
        predicted_index = choice_scores.index(max(choice_scores))

        hit = predicted_index == example["answer"]
        num_correct += int(hit)
        num_total += 1
        per_subject_correct[subject] += int(hit)
        per_subject_total[subject] += 1

    # Aggregate accuracies.
    overall_accuracy = num_correct / num_total if num_total > 0 else 0.0
    subject_scores = {
        subject: per_subject_correct[subject] / per_subject_total[subject]
        for subject in per_subject_correct
    }

    return {
        "accuracy": overall_accuracy,
        "num_correct": num_correct,
        "num_total": num_total,
        "subject_scores": subject_scores,
    }

0 commit comments

Comments
 (0)