-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgenerate_judge_training_data.py
More file actions
120 lines (94 loc) · 3.08 KB
/
generate_judge_training_data.py
File metadata and controls
120 lines (94 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import concurrent.futures
import json
import os
import random
from dotenv import load_dotenv
from tqdm import tqdm
from agents import BaseDebateAgent
from agents import DebateAgent
from agents import run_debate
from models import DebateSample
from models import MMLUProDebateSample
random.seed(314)
load_dotenv()
BLIND_TEMPLATE = """Question:
{question}
A) {first_option}
B) {second_option}
Please answer the question. Say "A" or "B" and NOTHING ELSE."""
TEMPLATE = """Question:
{question}
A) {first_option}
B) {second_option}
Below is a debate with one side arguing for A and the other side arguing for B.
{debate_so_far}
Read the debate carefully and then answer the question. Say "A" or "B" and NOTHING ELSE."""
openai_api_key = os.getenv("OPENAI_API_KEY")
assert openai_api_key is not None, "The OPENAI_API_KEY environment variable is not set."
with open("binarised_mmlu_pro.json") as f:
samples = [MMLUProDebateSample.model_validate(sample) for sample in json.load(f)]
agent = DebateAgent.from_model(
model="gpt-4o-2024-08-06",
api_key=openai_api_key,
temperature=0.0,
)
MAX_SAMPLES = 500
samples = samples[:MAX_SAMPLES]
pbar = tqdm(total=len(samples))
def create_debate_sample(
mmlu_pro_sample: MMLUProDebateSample,
agent: BaseDebateAgent,
number_of_turns: int = 6,
):
reverse_labels = random.choice([True, False])
if reverse_labels:
first_option = mmlu_pro_sample.incorrect_answer
second_option = mmlu_pro_sample.correct_answer
label = "B"
else:
first_option = mmlu_pro_sample.correct_answer
second_option = mmlu_pro_sample.incorrect_answer
label = "A"
blind_judge_prompt = BLIND_TEMPLATE.format(
question=mmlu_pro_sample.question,
first_option=first_option,
second_option=second_option,
)
full_debate = run_debate(
question=mmlu_pro_sample.question,
position=first_option,
opposing_position=second_option,
agent=agent,
opponent_agent=agent,
number_of_turns=number_of_turns,
)
speakers = [
"Person advocating for A: ",
"Person advocating for B: ",
] * number_of_turns
debate_so_far = "".join(
f"\n\n{speaker}\n{turn}" for speaker, turn in zip(speakers, full_debate)
)
judge_prompt = TEMPLATE.format(
question=mmlu_pro_sample.question,
first_option=first_option,
second_option=second_option,
debate_so_far=debate_so_far,
)
pbar.update(1)
return DebateSample(
question=mmlu_pro_sample.question,
first_answer=first_option,
second_answer=second_option,
label=label,
blind_judge_prompt=blind_judge_prompt,
judge_prompt=judge_prompt,
)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
debate_samples = list(
executor.map(lambda sample: create_debate_sample(sample, agent), samples)
)
pbar.close()
with open("debate_samples_gpt4o_08_06.json", "w") as f:
json.dump([sample.model_dump() for sample in debate_samples], f)
print(f"Total cost: {agent.llm.total_cost()}")