-
Notifications
You must be signed in to change notification settings - Fork 5
/
translate_entity.py
91 lines (76 loc) · 2.86 KB
/
translate_entity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pandas as pd
from deep_translator import GoogleTranslator
from load_data import *
from tqdm import tqdm
import argparse
def google_translate(text, tgt="en"):
    """Machine-translate ``text`` with Google Translate (via ``deep_translator``).

    Args:
        text: String to translate; the source language is auto-detected.
        tgt: Target language. ``"en"`` translates directly to English.
            ``"ko"`` first translates to English and then translates that
            result to Korean (an English pivot). Other values are not
            supported and leave ``text`` unchanged.

    Returns:
        The translated string, or ``text`` unchanged when the target is
        unsupported, the translator returns an empty result, or the
        translation request fails.
    """
    if tgt == "en":
        steps = ("en",)
    elif tgt == "ko":
        # Pivot through English: auto -> en, then (auto-detected) -> ko.
        steps = ("en", "ko")
    else:
        # Previously this fell through with `new_txt` unbound and was reported
        # as a generic failure; say what is actually wrong instead.
        print(f"Unsupported target language {tgt!r}; returning text unchanged.")
        return text

    try:
        new_txt = text
        for target in steps:
            new_txt = GoogleTranslator(source='auto', target=target).translate(new_txt)
    except Exception:
        # Best effort: translation is an augmentation step, so any API/network
        # failure keeps the original text rather than aborting the whole run.
        # (Narrower than a bare `except:`, so Ctrl-C still stops the script.)
        print('Problem Occurred!! (Maybe Minor One, Do Not Worry)')
        return text

    # An empty/None result is treated as "no translation available".
    return new_txt if new_txt else text
def check_translate(word, sent, *, translator=None):
    """Return a surface form of ``word`` that actually occurs in ``sent``.

    If ``word`` already appears in ``sent`` it is returned unchanged.
    Otherwise ``word`` is translated into the "other" language: words
    containing any non-ASCII character go to English (``"en"``), and pure
    ASCII words go to Korean (``"ko"``). The translation is returned only
    if it occurs in ``sent``; otherwise the original ``word`` is kept.

    Args:
        word: Entity string, without surrounding quotes.
        sent: Sentence the entity is expected to appear in.
        translator: Optional ``callable(text, tgt) -> str`` used instead of
            ``google_translate`` (e.g. another backend). Keyword-only;
            defaults to ``google_translate``.

    Returns:
        ``word`` or its translation, whichever is found in ``sent``
        (``word`` if neither is).
    """
    if word in sent:
        return word
    if translator is None:
        translator = google_translate

    # Classify by script. ``bytes.isalpha()`` is False for spaces, digits and
    # punctuation, so a check based on it treats ASCII entities such as
    # "Samsung Electronics" as non-English. Checking for ASCII-only
    # characters keeps Hangul/Hanja going to English and all ASCII to Korean.
    is_ascii = all(ord(ch) < 128 for ch in word)
    new_word = translator(word, "ko" if is_ascii else "en")

    # An empty translation would trivially satisfy `in sent`, so reject it.
    if not new_word or new_word not in sent:
        return word
    return new_word
def modify_entity(sub, ob, sent):
    """Replace subject/object entity strings that do not occur in ``sent``.

    Each entity may be wrapped in single quotes (e.g. ``"'word'"``). The
    quotes are removed before the lookup, and each result is returned
    wrapped in single quotes. ``check_translate`` picks the replacement: the
    entity itself if it occurs in ``sent``, else its translation if that
    occurs, else the entity unchanged.

    Args:
        sub: Subject entity string, optionally single-quoted.
        ob: Object entity string, optionally single-quoted.
        sent: Sentence both entities are expected to appear in.

    Returns:
        Tuple ``(new_sub, new_ob)``, each wrapped in single quotes.
    """
    def _unquote(entity):
        # Strip quotes from each entity independently, so an unquoted entity
        # is never trimmed just because the other one was quoted. Checking
        # both ends also avoids indexing into an empty string.
        if len(entity) >= 2 and entity[0] == entity[-1] == "'":
            return entity[1:-1]
        return entity

    new_sub = check_translate(_unquote(sub), sent)
    new_ob = check_translate(_unquote(ob), sent)
    return "'" + new_sub + "'", "'" + new_ob + "'"
# Command-line driver: for rows [START_IDX, END_IDX) of the input CSV, rewrite
# subject/object entities of rows whose label is not 'no_relation' via
# modify_entity, then write SAVE/train_trans{START_IDX}_{END_IDX}.csv.
# Imported here, after `from load_data import *`, so no star-exported name
# can shadow it.
import os

parser = argparse.ArgumentParser()
parser.add_argument('--FILE', type=str, default="train_trans.csv")
parser.add_argument('--SAVE', type=str, default="../dataset/train/translate/")
parser.add_argument('--START_IDX', type=int, default=0)
parser.add_argument('--END_IDX', type=int, default=5000)
args = parser.parse_args()

# The input file is looked up inside the training dataset directory.
file_dir = os.path.join("../dataset/train/", args.FILE)
cur_df = pd.read_csv(file_dir)

start_idx = args.START_IDX
# Clamp so the output file name reflects the rows actually processed.
end_idx = min(args.END_IDX, len(cur_df))
save_dir = args.SAVE

rows = cur_df.iloc[start_idx:end_idx]
subjects = list(rows['subject_entity'])
objects = list(rows['object_entity'])
sentences = list(rows['sentence'])
labels = list(rows['label'])
ids = list(rows['id'])

for i, label in enumerate(tqdm(labels, desc="Modifying Entities...")):
    # Only pairs that carry a relation have their entities rewritten.
    if label != 'no_relation':
        subjects[i], objects[i] = modify_entity(subjects[i], objects[i], sentences[i])

new_df = pd.DataFrame({
    'id': ids,
    'sentence': sentences,
    'subject_entity': subjects,
    'object_entity': objects,
    'label': labels,
})

out_name = f'train_trans{start_idx}_{end_idx}.csv'
# os.path.join tolerates a SAVE value with or without a trailing slash;
# create the directory so the first run does not fail on a missing folder.
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
new_df.to_csv(os.path.join(save_dir, out_name), index=False)
print(f'{out_name} entities modified!!')