translate.py
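"""Translate a JSON dataset with an mBART-50 model, optionally prepending a
context sentence, and store each output under the sample's "translations" key.

Example invocation (illustrative only; the paths and model name below are
assumptions, not part of this repository):

    python translate.py \
        --src data/samples.json \
        --model facebook/mbart-large-50-many-to-many-mmt \
        --output data/samples.de.json \
        --target_lang de \
        --device cuda:0 \
        --shortname mbart50
"""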
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import json
from tqdm import tqdm
import argparse

# Short flags accepted by --target_lang, mapped to mBART-50 language codes.
lang_codes = {
    "de": "de_DE",
    "zh": "zh_CN"
}

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", required=True, type=str)
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--output", required=True, type=str)
    parser.add_argument("--target_lang", required=True, type=str)
    parser.add_argument("--add_context", action="store_true")
    parser.add_argument("--device", required=True, type=str)
    parser.add_argument("--shortname", required=True, type=str)
    return parser.parse_args()
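
# Expected shape of each entry in the --src JSON file (an assumption inferred
# from the keys accessed in main(); the real dataset may carry extra fields):
# {
#     "ctx_src": "...",      # preceding context sentence
#     "src": "...",          # English source sentence to translate
#     "tgt": "...",          # reference translation in the target language
#     "translations": {}     # per-model outputs, filled in by this script
# }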

def main():
    args = parse_args()
    samples = json.load(open(args.src))
    model = MBartForConditionalGeneration.from_pretrained(args.model).to(args.device)
    tokenizer = MBart50TokenizerFast.from_pretrained(args.model, src_lang="en_XX")
    for sample in tqdm(samples):
        # Optionally prepend the context sentence, separated by a <brk> marker.
        src = sample["ctx_src"] + " <brk> " + sample["src"] if args.add_context else sample["src"]
        inputs = tokenizer(src, return_tensors="pt").to(args.device)
        # Force the decoder to start with the target-language code.
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[lang_codes.get(args.target_lang)]
        )
        translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Create the per-sample translations dict if it is missing or empty.
        if not sample.get("translations"):
            sample["translations"] = {}
        sample["translations"][args.shortname] = {
            "src": src,
            "tgt": sample["tgt"],
            "translation": translation
        }
    with open(args.output, "w") as output:
        json.dump(samples, output)


if __name__ == "__main__":
    main()