|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import random |
9 | | -from typing import Dict, List, Optional |
| 9 | +from typing import Dict, List, Optional, Union |
10 | 10 |
|
11 | 11 | import einops |
12 | 12 | import torch |
@@ -85,6 +85,160 @@ def make_code_data_loader(tokenizer, batch_size=8): |
85 | 85 | return data_loader |
86 | 86 |
|
87 | 87 |
|
# All 57 subjects available in the MMLU benchmark.
# Names match the dataset config names on the HuggingFace hub
# (https://huggingface.co/datasets/cais/mmlu) exactly; they are used verbatim
# as the `name` argument to `load_dataset("cais/mmlu", ...)`.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Letters labelling the four answer choices of every MMLU question, in
# choice-index order (index 0 -> "A", ..., index 3 -> "D").
MMLU_ANSWER_LETTERS = ["A", "B", "C", "D"]
| 151 | + |
def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Load MMLU benchmark questions as a flat list of dictionaries.

    MMLU (Massive Multitask Language Understanding) covers 57 subjects across
    STEM, the humanities, the social sciences, and more; every question is
    multiple choice with four options (A, B, C, D). Data is fetched per subject
    from the HuggingFace hub.

    Paper: https://arxiv.org/abs/2009.03300
    Dataset: https://huggingface.co/datasets/cais/mmlu

    Args:
        subjects: Which subject(s) to load. ``None`` (the default) loads all 57
            subjects; a single subject name (e.g. "abstract_algebra") loads just
            that subject; a list of names loads each of them.
        split: Dataset split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional per-subject cap on the number of examples. ``None``
            keeps every example.

    Returns:
        A list of dicts, one per question, each with keys:
            - "question": str
            - "choices": List[str] (4 choices)
            - "answer": int (0-3, correct choice index)
            - "subject": str

    Raises:
        ValueError: If any requested subject is not a valid MMLU subject name.

    Examples:

    .. code-block:: python

        >>> from transformer_lens.evals import make_mmlu_data_loader

        >>> # Load specific subject
        >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics")  # doctest: +SKIP

        >>> # Load multiple subjects
        >>> mmlu_data = make_mmlu_data_loader(  # doctest: +SKIP
        ...     subjects=["abstract_algebra", "astronomy", "college_chemistry"]
        ... )
    """
    # Normalize the ``subjects`` argument into a concrete list of names.
    if subjects is None:
        requested = MMLU_SUBJECTS
    elif isinstance(subjects, str):
        requested = [subjects]
    else:
        requested = list(subjects)

    # Fail fast on any name that is not a known MMLU subject.
    unknown = set(requested) - set(MMLU_SUBJECTS)
    if unknown:
        raise ValueError(
            f"Invalid subject(s): {unknown}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    examples = []
    for name in requested:
        # Best-effort per subject: a subject that fails to download/parse is
        # skipped with a warning rather than aborting the whole load.
        try:
            subject_data = load_dataset("cais/mmlu", name, split=split)

            # Apply the optional per-subject sample cap.
            if num_samples is None:
                limit = len(subject_data)
            else:
                limit = min(num_samples, len(subject_data))

            # Re-key each row into this module's uniform example format.
            for row_idx in range(limit):
                row = subject_data[row_idx]
                examples.append(
                    {
                        "question": row["question"],
                        "choices": row["choices"],
                        "answer": row["answer"],
                        "subject": name,
                    }
                )
        except Exception as e:
            print(f"Warning: Could not load subject '{name}': {e}")
            continue

    print(f"Loaded {len(examples)} MMLU examples from {len(requested)} subject(s)")
    return examples
| 240 | + |
| 241 | + |
88 | 242 | DATASET_NAMES = ["wiki", "owt", "pile", "code"] |
89 | 243 | DATASET_LOADERS = [ |
90 | 244 | make_wiki_data_loader, |
@@ -334,3 +488,137 @@ def collate(samples): |
334 | 488 | "Logit Difference": total_logit_diff / len(dataset), |
335 | 489 | "Accuracy": total_correct / len(dataset), |
336 | 490 | } |
| 491 | + |
| 492 | + |
@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Evaluate a model on the MMLU benchmark.

    MMLU (Massive Multitask Language Understanding) tests language models on 57
    subjects spanning STEM, humanities, social sciences, and more; each question
    is multiple-choice with 4 options.

    Every question is formatted with all four choices (A-D) in the prompt, and
    the model's log probability for each answer-letter token at the final
    position is compared; the highest-scoring letter is the prediction. This is
    a zero-shot evaluation; standard MMLU benchmarks typically use 5-shot
    prompting for higher accuracy.

    Paper: https://arxiv.org/abs/2009.03300

    Args:
        model: HookedTransformer model to evaluate.
        tokenizer: Tokenizer to use. If None, uses model.tokenizer.
        subjects: Subject(s) to evaluate on. Can be None (all 57 subjects), a single
            subject string, or a list of subjects. See :const:`MMLU_SUBJECTS` for
            valid names.
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses
            all samples.

    Returns:
        Dictionary containing:
            - "accuracy": Overall accuracy (0-1)
            - "num_correct": Number of correct predictions
            - "num_total": Total number of questions
            - "subject_scores": Dict mapping subject names to their accuracy

    Raises:
        ValueError: If no MMLU data could be loaded for the given arguments.

    Examples:

    .. code-block:: python

        >>> from transformer_lens import HookedTransformer
        >>> from transformer_lens.evals import mmlu_eval

        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10)  # doctest: +SKIP
        >>> print(f"Accuracy: {results['accuracy']:.2%}")  # doctest: +SKIP
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    examples = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)
    if not examples:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    # Resolve each answer letter to a single token id once, up front, rather
    # than per question. The space-prefixed form is tried first because that is
    # how the letter appears after "Answer:"; if it doesn't tokenize to exactly
    # one token, fall back to the bare letter's first token.
    letter_token_ids = []
    for letter in MMLU_ANSWER_LETTERS:
        spaced = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(spaced) == 1:
            letter_token_ids.append(spaced[0])
        else:
            letter_token_ids.append(tokenizer.encode(letter, add_special_tokens=False)[0])

    # Running tallies, overall and per subject.
    num_correct = 0
    num_total = 0
    hits_by_subject: Dict[str, int] = {}
    counts_by_subject: Dict[str, int] = {}

    for item in tqdm.tqdm(examples, desc="Evaluating MMLU"):
        subject = item["subject"]
        options = item["choices"]
        hits_by_subject.setdefault(subject, 0)
        counts_by_subject.setdefault(subject, 0)

        # Assemble the standard MMLU prompt: question, lettered choices, and a
        # trailing "Answer:" cue.
        pieces = [f"Question: {item['question']}\n", "Choices:\n"]
        for pos, option in enumerate(options):
            pieces.append(f"{chr(65 + pos)}. {option}\n")  # 65 == ord("A")
        pieces.append("Answer:")
        prompt = "".join(pieces)

        token_tensor = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)
        logits = model(token_tensor, return_type="logits")

        # Log probabilities at the final position, where the letter is predicted.
        final_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)

        # Score each choice by its letter token's log probability and take the
        # (first) highest-scoring index as the prediction.
        scores = [final_log_probs[letter_token_ids[pos]].item() for pos in range(len(options))]
        prediction = max(range(len(scores)), key=scores.__getitem__)

        hit = int(prediction == item["answer"])
        num_correct += hit
        num_total += 1
        hits_by_subject[subject] += hit
        counts_by_subject[subject] += 1

    # Aggregate into accuracies (guard against an empty run).
    overall_accuracy = num_correct / num_total if num_total > 0 else 0.0
    subject_scores = {
        name: hits_by_subject[name] / counts_by_subject[name] for name in hits_by_subject
    }

    return {
        "accuracy": overall_accuracy,
        "num_correct": num_correct,
        "num_total": num_total,
        "subject_scores": subject_scores,
    }
0 commit comments