-
Notifications
You must be signed in to change notification settings - Fork 155
/
Copy pathprott5_embedder.py
187 lines (147 loc) · 7.46 KB
/
prott5_embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 23 18:33:22 2020
@author: mheinzinger
"""
import argparse
import time
from pathlib import Path
import torch
import h5py
from transformers import T5EncoderModel, T5Tokenizer
# Use the first CUDA device when torch reports one is available, otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: {}".format(device))
def get_T5_model(model_dir, transformer_link = "Rostlab/prot_t5_xl_half_uniref50-enc"):
    """Load the ProtT5 encoder model and its tokenizer.

    Parameters
    ----------
    model_dir : path-like or None
        Directory forwarded as ``cache_dir`` to ``from_pretrained`` for both
        the model and the tokenizer. ``None`` uses the library's default cache.
    transformer_link : str
        Model identifier passed to ``from_pretrained``.

    Returns
    -------
    tuple
        ``(model, vocab)``: the encoder moved to the module-level ``device``
        and switched to eval mode, and the matching ``T5Tokenizer``.
    """
    print("Loading: {}".format(transformer_link))
    if model_dir is not None:
        print("##########################")
        print("Loading cached model from: {}".format(model_dir))
        print("##########################")
    model = T5EncoderModel.from_pretrained(transformer_link, cache_dir=model_dir)
    # Only cast to full precision when no GPU is available. The default
    # checkpoint name ("..._half_...") suggests it is stored in fp16.
    if device == torch.device("cpu"):
        print("Casting model to full precision for running on CPU ...")
        model = model.to(torch.float32)
    model = model.to(device)
    model = model.eval()
    # Load the tokenizer from the same cache directory as the model, so a
    # user-supplied model_dir is honoured for both artifacts.
    vocab = T5Tokenizer.from_pretrained(transformer_link, do_lower_case=False, cache_dir=model_dir)
    return model, vocab
def read_fasta(fasta_path):
    """Read a FASTA file with one or more records into a ``{identifier: sequence}`` dict.

    The identifier is the header line with every '>' removed and surrounding
    whitespace stripped. '/' and '.' are replaced by '_' because those
    characters are mis-interpreted when the identifier is later used as an
    HDF5 dataset name.

    Sequence lines are joined across lines, all whitespace is removed, gap
    characters ('-') are dropped and residues are upper-cased. Blank lines are
    ignored. If an identifier occurs twice, the later record's sequence
    replaces the earlier one.

    Raises
    ------
    ValueError
        If sequence data appears before the first header line.
    """
    sequences = dict()
    uniprot_id = None  # no header seen yet
    with open(fasta_path, 'r') as fasta_f:
        for line_no, line in enumerate(fasta_f, 1):
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip()
                # replace tokens that are mis-interpreted when loading h5
                uniprot_id = uniprot_id.replace("/", "_").replace(".", "_")
                sequences[uniprot_id] = ''
                continue
            # remove all white-space chars; join seqs spanning multiple lines
            residues = ''.join(line.split())
            if not residues:
                # Blank lines (including leading ones) carry no sequence data.
                continue
            if uniprot_id is None:
                raise ValueError(
                    "Sequence data found before the first FASTA header "
                    "(line {}) in {}".format(line_no, fasta_path))
            # drop gaps and cast to upper-case
            sequences[uniprot_id] += residues.upper().replace("-", "")
    return sequences
def _iter_batches(seq_items, max_residues, max_seq_len, max_batch):
    """Split ``(identifier, sequence)`` pairs into batches for the encoder.

    Yields lists of ``(identifier, spaced_sequence, seq_len)`` tuples. The rare
    residues U, Z and O are mapped to X, and ``spaced_sequence`` has one space
    between residues (the format handed to the tokenizer).

    ``seq_items`` must be a list sorted by decreasing sequence length. Under
    that ordering:

    - Every sequence longer than ``max_seq_len`` is flushed in a batch of its own.
    - The residue check adds the newest length a second time, as a stand-in for
      the next (no longer) sequence. As a result, a batch holding more than one
      sequence never reaches ``max_residues``.
    """
    batch = []
    n_res = 0  # residues currently held in ``batch``
    n_total = len(seq_items)
    for seq_idx, (identifier, seq) in enumerate(seq_items, 1):
        seq = seq.replace('U', 'X').replace('Z', 'X').replace('O', 'X')
        seq_len = len(seq)
        batch.append((identifier, ' '.join(seq), seq_len))
        n_res += seq_len
        if (len(batch) >= max_batch
                or n_res + seq_len >= max_residues
                or seq_idx == n_total
                or seq_len > max_seq_len):
            yield batch
            batch = []
            n_res = 0


def get_embeddings( seq_path,
                    emb_path,
                    model_dir,
                    per_protein, # whether to derive per-protein (mean-pooled) embeddings
                    max_residues=4000, # number of cumulative residues per batch
                    max_seq_len=1000, # max length after which we switch to single-sequence processing to avoid OOM
                    max_batch=100 # max number of sequences per single batch
                    ):
    """Embed every sequence of a FASTA file with ProtT5 and write them to HDF5.

    Parameters
    ----------
    seq_path : path-like
        FASTA file, parsed with ``read_fasta``.
    emb_path : path-like
        Output HDF5 file, opened in "w" mode. It contains one dataset per
        embedded sequence, named by its identifier.
    model_dir : path-like or None
        Cache directory forwarded to ``get_T5_model``.
    per_protein : bool
        If true, store the mean over residues (shape ``(dim,)``). Otherwise,
        store the per-residue embeddings (shape ``(seq_len, dim)``).
    max_residues, max_seq_len, max_batch : int
        Batching limits, see ``_iter_batches``.

    Batches that raise ``RuntimeError`` during the forward pass (e.g. out of
    memory) are reported and skipped. Empty sequences are reported and skipped.

    Returns
    -------
    bool
        ``True`` once the output file has been written.

    Raises
    ------
    ValueError
        If the FASTA file contains no sequences.
    """
    # Read in fasta
    seq_dict = read_fasta(seq_path)
    if not seq_dict:
        # Fail before loading the (large) model instead of crashing later on
        # the example print or the average-length division.
        raise ValueError("No sequences found in {}".format(seq_path))

    model, vocab = get_T5_model(model_dir)

    example_id, example_seq = next(iter(seq_dict.items()))
    print('########################################')
    print('Example sequence: {}\n{}'.format(example_id, example_seq))
    print('########################################')
    print('Total number of sequences: {}'.format(len(seq_dict)))

    avg_length = sum(len(seq) for seq in seq_dict.values()) / len(seq_dict)
    n_long = sum(1 for seq in seq_dict.values() if len(seq) > max_seq_len)
    print("Average sequence length: {}".format(avg_length))
    print("Number of sequences >{}: {}".format(max_seq_len, n_long))

    # Longest first, as _iter_batches requires. The sort is stable, so ties
    # keep their file order.
    seq_items = sorted(seq_dict.items(), key=lambda kv: len(kv[1]), reverse=True)
    # An empty sequence would give a (0, dim) array or a NaN mean; skip it.
    empty_ids = [seq_id for seq_id, seq in seq_items if not seq]
    if empty_ids:
        print("Skipping {} empty sequence(s): {}".format(len(empty_ids), ", ".join(empty_ids)))
        seq_items = [(seq_id, seq) for seq_id, seq in seq_items if seq]

    emb_dict = dict()
    start = time.time()
    for batch in _iter_batches(seq_items, max_residues, max_seq_len, max_batch):
        pdb_ids, seqs, seq_lens = zip(*batch)
        token_encoding = vocab.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(token_encoding['input_ids']).to(device)
        attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
        try:
            with torch.no_grad():
                embedding_repr = model(input_ids, attention_mask=attention_mask)
        except RuntimeError:
            # Batches are filled from length-sorted input, so the first member
            # is the longest one and the most useful one to report.
            print("RuntimeError during embedding for {} (L={}; batch of {} sequence(s)). Try lowering batch size. ".format(
                      pdb_ids[0], seq_lens[0], len(pdb_ids)) +
                  "If single sequence processing does not work, you need more vRAM to process your protein.")
            continue

        # batch-size x seq_len x embedding_dim. An extra token is added at
        # the end of each seq, so the first s_len positions are the residues
        # (padding and special tokens are sliced off).
        hidden = embedding_repr.last_hidden_state
        for batch_idx, (identifier, s_len) in enumerate(zip(pdb_ids, seq_lens)):
            emb = hidden[batch_idx, :s_len]
            if per_protein:
                emb = emb.mean(dim=0)
            if len(emb_dict) == 0:
                print("Embedded protein {} with length {} to emb. of shape: {}".format(
                    identifier, s_len, emb.shape))
            # No squeeze(): it would collapse a length-1 per-residue embedding
            # from (1, dim) to (dim,), breaking the (seq_len, dim) layout.
            emb_dict[identifier] = emb.detach().cpu().numpy()
    end = time.time()

    with h5py.File(str(emb_path), "w") as hf:
        for sequence_id, embedding in emb_dict.items():
            hf.create_dataset(sequence_id, data=embedding)

    n_emb = len(emb_dict)
    elapsed = end - start
    print('\n############# STATS #############')
    print('Total number of embeddings: {}'.format(n_emb))
    if n_emb:
        print('Total time: {:.2f}[s]; time/prot: {:.4f}[s]; avg. len= {:.2f}'.format(
            elapsed, elapsed / n_emb, avg_length))
    else:
        # Every batch failed (or nothing was embeddable): avoid dividing by zero.
        print('Total time: {:.2f}[s]; no embeddings were created.'.format(elapsed))
    return True
def create_arg_parser():
    """Create and return the command-line ArgumentParser for this script."""
    parser = argparse.ArgumentParser(description=(
        't5_embedder.py creates T5 embeddings for a given text ' +
        'file containing sequence(s) in FASTA-format.'))
    # Required option: input FASTA file
    parser.add_argument('-i', '--input', required=True, type=str,
                        help='A path to a fasta-formatted text file containing protein sequence(s).')
    # Required option: output file (written with h5py)
    parser.add_argument('-o', '--output', required=True, type=str,
                        help='A path for saving the created embeddings as HDF5 (.h5) file.')
    # Optional: directory forwarded as cache_dir when loading the model
    parser.add_argument('--model', required=False, type=str,
                        default=None,
                        help='A path to a directory used as cache (cache_dir) for the pre-trained model checkpoint')
    # Optional: 0 -> per-residue embeddings, non-zero -> mean-pooled per-protein
    parser.add_argument('--per_protein', type=int,
                        default=0,
                        help="Whether to return per-residue embeddings (0: default) or the mean-pooled per-protein representation (1).")
    return parser
def main():
    """Parse command-line arguments and run the embedding pipeline."""
    args = create_arg_parser().parse_args()
    # Any non-zero --per_protein value selects mean-pooled embeddings.
    per_protein = bool(int(args.per_protein))
    model_dir = None if args.model is None else Path(args.model)
    get_embeddings(Path(args.input), Path(args.output), model_dir, per_protein=per_protein)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()