-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrunner.py
More file actions
161 lines (129 loc) · 6.9 KB
/
runner.py
File metadata and controls
161 lines (129 loc) · 6.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import random
from typing import List, Union
from ruamel.yaml import YAML
import copy
import argparse
import json
from typing import Dict
import numpy as np
import torch
from tqdm import tqdm
from transformers import GenerationConfig
from tokenizers import Tokenizer
from progen_conditional.model import ProgenConditional
from progen_conditional.data import get_tokenizer, PAD_TOKEN_ID
from scripts.utils import taxname2number
CKPT_DIR = "results/"
class Runner():
    """
    Runs generation on trained checkpoints with conditional adapters.

    Loads a ``ProgenConditional`` model either from a local safetensors
    checkpoint under ``results/<model_name>/huggingface/<checkpoint_name>`` or
    from the HuggingFace hub, together with its tokenizer and the per-condition
    encoding tables, then exposes :meth:`sample` for batched conditional
    generation and :meth:`save_seqs` for writing sequences to a FASTA file.
    """
    def __init__(self, model_name, checkpoint_name="1.5B", device="cuda") -> None:
        """
        Load the trained model, tokenizer, and condition encodings, and seed RNGs.

        Args:
            model_name: Subdirectory of ``results/`` holding the trained model.
            checkpoint_name: Checkpoint folder under ``<model_dir>/huggingface``.
            device: Torch device string the model is moved to (e.g. ``"cuda"``).
        """
        #load the training config to determine how to load the trained model
        self.model_name = model_name
        self.checkpoint_name = checkpoint_name
        self.device = device

        #load a model with conditional adapters
        self.model_dir = os.path.join(CKPT_DIR, model_name)
        ckpt_file = os.path.join(self.model_dir, 'huggingface', checkpoint_name)
        if os.path.exists(os.path.join(ckpt_file, "model.safetensors")):
            #if there is a local safetensors file to load
            self.model = ProgenConditional.from_pretrained(ckpt_file)
            self.tokenizer = get_tokenizer()
        else:
            #download the model from the huggingface model hub and cache it locally
            self.model = ProgenConditional.from_pretrained("jsunn-y/ProCALM", subfolder="{}/{}".format(model_name, checkpoint_name), cache_dir=ckpt_file)
            self.tokenizer = Tokenizer.from_pretrained("jsunn-y/ProCALM")

        self.progenconditional_config = self.model.config
        self.model.to(device)
        self.model.eval()
        self.pad_token_id = PAD_TOKEN_ID

        #load the dictionary mapping each condition type (e.g. "ec", "tax") to its encoding table
        # NOTE(review): torch.load deserializes arbitrary pickles — only load
        # encoding files from trusted sources (consider weights_only=True).
        self.encoding_dicts = {}
        for key, encoding_file in self.progenconditional_config.encoding_files.items():
            self.encoding_dicts[key] = torch.load(encoding_file)

        #fix all RNG seeds so repeated runs are reproducible
        np.random.seed(42)
        random.seed(42)
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True

    def sample(
        self,
        conditions: Dict[str, str] = None, #dictionary mapping the type of condition (ec, tax) to the string specifying the condition
        context= '1', #start token
        num_return_sequences=45, #effectively the batch size
        temperature=0.3, #0.5 and 1 are a bit worse in performance
        top_p=0.95, #0.9 or 0.95
        max_length=1024,
    ):
        """
        Runs one batch of generation with the specified conditions.

        Args:
            conditions: Maps condition type ("ec", "tax") to the condition
                string. ``None`` or missing keys fall back to unconditional
                generation (a zero encoding) for that condition type.
            context: Start token string fed to the tokenizer.
            num_return_sequences: Number of sequences generated (batch size).
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            max_length: Maximum generated token length.

        Returns:
            List of decoded sequence strings, length ``num_return_sequences``.
        """
        # BUG FIX: the default conditions=None previously crashed on
        # conditions.get(...) below; treat None as "no conditions".
        if conditions is None:
            conditions = {}
        self.temp = 'temp' + str(temperature)

        #check if the condition types are supported by the trained model. If not, do unconditional generation for that condition.
        self.ec = "no-ec" if "ec" not in self.encoding_dicts.keys() else conditions.get('ec', "no-ec")
        self.tax = "no-tax" if "tax" not in self.encoding_dicts.keys() else conditions.get('tax', "no-tax")

        condition_encodings = {}
        for key, encoding_dict in self.encoding_dicts.items():
            condition = conditions.get(key, None)
            if condition is not None:
                # NOTE(review): raises KeyError if the condition value was not
                # seen at training time — verify callers pass known conditions.
                condition_encodings[key] = encoding_dict[condition].to(self.device)
            else:
                #unconditional generation: a zero vector of the condition's encoding dimension
                condition_encodings[key] = torch.zeros(1, self.progenconditional_config.encoding_dimensions[key]).to(self.device)

        #running things packaged into the huggingface class (alternatively could use beam search instead of probabilistic decoding)
        with torch.no_grad():
            input_ids = torch.tensor(self.tokenizer.encode(context).ids).view([1, -1]).to(self.device)
            tokens_batch = self.model.generate(input_ids=input_ids, condition_encodings=condition_encodings, do_sample=True, temperature=temperature, max_length=max_length, top_p=top_p, num_return_sequences=num_return_sequences, pad_token_id=self.pad_token_id, eos_token_id=4)

        as_lists = lambda batch: [batch[i, ...].detach().cpu().numpy().tolist() for i in range(batch.shape[0])]
        self.sequences = self.tokenizer.decode_batch(as_lists(tokens_batch))
        return self.sequences

    def save_seqs(self, sequences):
        """
        Saves the list of generated sequences to a fasta file.

        The output path encodes the temperature and conditions from the most
        recent :meth:`sample` call, so ``sample`` must be called first.

        Args:
            sequences: List of sequence strings to write.
        """
        #ensure the directory exists
        os.makedirs(os.path.join(self.model_dir, 'generated', self.checkpoint_name, self.temp), exist_ok=True)
        with open(os.path.join(self.model_dir, 'generated', self.checkpoint_name, self.temp, "sequences_{}_{}.fasta".format(self.ec, self.tax)), "w") as f:
            for i, seq in enumerate(sequences):
                f.write(f">EC_{self.ec}_tax_{self.tax}_{i}\n")
                f.write(f"{seq}\n")
def parse_args():
    """Build the command-line parser for generation and return the parsed args."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, help="Model name to load.")
    cli.add_argument("--checkpoint", default="latest", type=str, help="Checkpoint name to load.")
    cli.add_argument("--ec", default=None, type=str, help="EC number to conditionally generate from. train+test specifies a list of curated ECs.")
    cli.add_argument("--tax", default=None, type=str, help="Taxonomy lineage IDS to conditionally generate from")
    cli.add_argument("--temp", default=0.3, type=float, help="Temperature for generation")
    cli.add_argument("--top_p", default=0.95, type=float, help="Top p for generation")
    cli.add_argument("--batch_size", default=45, type=int, help="Batch size for generation.")
    cli.add_argument("--num_seqs", default=990, type=int, help="Number of sequences to generate")
    return cli.parse_args()
def main():
    """
    CLI entry point: generate ``num_seqs`` sequences in batches under the given
    EC/taxonomy conditions and checkpoint the cumulative FASTA after each batch.
    """
    args = parse_args()
    # Run relative to this script so the relative CKPT_DIR resolves regardless of CWD.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    #reintialize here so the seed is set correctly
    runner = Runner(model_name=args.model, checkpoint_name=args.checkpoint)

    ec = args.ec
    tax = args.tax
    # Validate up front instead of failing with an opaque KeyError below.
    if tax is not None and tax not in taxname2number:
        raise ValueError("Taxonomy must be one of bacteria, archaea, eukaryota, or viruses")
    tax = taxname2number[tax] if tax is not None else None

    conditions = {}
    if ec is not None:
        conditions['ec'] = ec
    if tax is not None:
        conditions['tax'] = tax

    #45 is the max batch size that fits on 40GB A100; sequences beyond a whole batch are dropped
    num_batches = args.num_seqs // args.batch_size
    all_sequences = []
    for _ in tqdm(range(num_batches), desc=f"Generating sequences for EC {ec} and tax {tax}"):
        sequences = runner.sample(conditions=conditions, temperature=args.temp, num_return_sequences=args.batch_size, top_p=args.top_p)
        all_sequences.extend(sequences)
        # Rewrite the cumulative FASTA every batch so partial results survive interruption.
        runner.save_seqs(all_sequences)


if __name__ == "__main__":
    main()