-
Notifications
You must be signed in to change notification settings - Fork 5
/
create_augments.py
95 lines (66 loc) · 3 KB
/
create_augments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from koeda import EDA, AEDA
import pandas as pd
import numpy as np
import re
from load_data import *
from tqdm import tqdm
from random import random
import random
# Training data is loaded at import time; main() hands it to augment_sent().
data = load_data("../dataset/train/train.csv")

# ---- Adjust alpha, prob, ratio, punctuations, etc. below as needed ----
# eda1 and eda2 share the same rates and differ only in the morpheme analyzer.
eda1 = EDA(
    morpheme_analyzer="Mecab",
    alpha_sr=0.3,
    alpha_ri=0.2,
    alpha_rs=0.1,
    prob_rd=0.1,
)
eda2 = EDA(
    morpheme_analyzer="Okt",
    alpha_sr=0.3,
    alpha_ri=0.2,
    alpha_rs=0.1,
    prob_rd=0.1,
)
aeda = AEDA(
    morpheme_analyzer="Mecab",
    punc_ratio=0.05,
    punctuations=[
        "?", "!", ".", ",", "##",
        "한편", "또한", "결국", "그래서", "따라서",
        "사실", "그럼에도", "당연하게도", "그리고", "이어서",
    ],
)
# -----------------------------------------------------------------------

# Name of the augmented file written by main(); it is saved under
# ../dataset/train/.
FILE_NAME = 'train_okt.csv'
# ---- Set the output file name above; it must end in .csv ----
#### The following 2 functions support the EDA & AEDA augmentations ####
def right_augment(en1, en2, text):
    """Return True when both entity mentions are still present in ``text``.

    The first and last character of each entity are dropped before the
    containment check (presumably surrounding quote characters - confirm
    against what load_data produces).

    Args:
        en1: First entity, including its one-character wrapper on each side.
        en2: Second entity, wrapped the same way.
        text: Augmented sentence to inspect.

    Returns:
        True if both unwrapped entities occur in ``text``, otherwise False.
    """
    for wrapped_entity in (en1, en2):
        # Bail out on the first missing entity; the remaining one is not checked.
        if wrapped_entity[1:-1] not in text:
            return False
    return True
def apply_augment(func, sentence, en1, en2):
    """Augment ``sentence`` with ``func``, falling back to the original.

    The augmented text is accepted only if right_augment() confirms that
    both entities survived the augmentation.

    Args:
        func: Callable that takes the sentence and returns an augmented one.
        sentence: Original sentence.
        en1: First entity, passed through to right_augment().
        en2: Second entity, passed through to right_augment().

    Returns:
        The augmented sentence when both entities are still present,
        otherwise the untouched ``sentence``.
    """
    candidate = func(sentence)
    return candidate if right_augment(en1, en2, candidate) else sentence
#### The following 2 functions implement random deletion of words from a sentence ####
def never_delete(e1, e2, word):
    """Tell whether ``word`` must be kept because it carries part of an entity.

    Each entity loses its first and last character, is split on whitespace,
    and ``word`` is protected if any resulting token is a substring of it.

    Args:
        e1: First entity, including its one-character wrapper on each side.
        e2: Second entity, wrapped the same way.
        word: A single whitespace-delimited word from the sentence.

    Returns:
        True if any entity token occurs inside ``word``, otherwise False.
    """
    unwrapped = (e1[1:-1], e2[1:-1])
    protected_tokens = [tok for entity in unwrapped for tok in entity.split()]
    # Substring match on purpose: "Hanks," still counts as containing "Hanks".
    return any(tok in word for tok in protected_tokens)
def apply_deletion(text, e1, e2, del_prob=0.15):
    """Randomly drop whitespace-separated words from ``text``, keeping entities.

    Every word gets one uniform draw in [0, 1]. It is kept when the draw
    exceeds ``del_prob`` or when never_delete() reports that the word
    contains part of either entity.

    The sentence is tokenised exactly as given. The earlier version also
    built a copy with '·' and ',' padded by spaces, but only after splitting
    and without ever reading it. That dead statement is removed here rather
    than applied, because applying it would rewrite those characters inside
    entity mentions.

    Args:
        text: Sentence to thin out.
        e1: First entity, passed through to never_delete().
        e2: Second entity, passed through to never_delete().
        del_prob: Draw threshold at or below which an unprotected word is
            dropped. Defaults to 0.15.

    Returns:
        The surviving words joined by single spaces; an empty string when
        ``text`` has no words or every word is dropped.
    """
    kept_words = []
    for word in text.split():
        # Draw before the entity check so one random number is used per word.
        p = random.uniform(0, 1)
        if p > del_prob or never_delete(e1, e2, word):
            kept_words.append(word)
    return ' '.join(kept_words)
#### Below is where all the augmentations are created ####
def augment_sent(df, my_eda1, my_eda2, my_aeda):
    """Build an augmented copy of ``df`` with one new sentence per row.

    For each row a uniform draw in [0, 1] selects the augmentation:
    ``my_eda1`` for draws up to 0.4, ``my_eda2`` up to 0.9, ``my_aeda`` up
    to 0.95, and random word deletion above that. The three augmenters go
    through apply_augment(), which returns the original sentence if an
    entity is lost.

    Args:
        df: Table with 'id', 'sentence', 'subject_entity', 'object_entity'
            and 'label' columns.
        my_eda1: First sentence-augmenting callable.
        my_eda2: Second sentence-augmenting callable.
        my_aeda: Third sentence-augmenting callable.

    Returns:
        A pandas DataFrame with the same five columns, where 'sentence'
        holds the augmented text and the other columns are carried over.
    """
    sentences = list(df['sentence'])
    subjects = list(df['subject_entity'])
    objects = list(df['object_entity'])
    # Materialised as a list so tqdm knows the total and can show progress.
    rows = list(zip(sentences, subjects, objects))

    augmented = []
    for sent, subj, obj in tqdm(rows, desc="Count 10 Seconds ..."):
        draw = random.uniform(0, 1)
        if draw > 0.95:
            new_sent = apply_deletion(sent, subj, obj)
        elif draw > 0.9:
            new_sent = apply_augment(my_aeda, sent, subj, obj)
        elif draw > 0.4:
            new_sent = apply_augment(my_eda2, sent, subj, obj)
        else:
            new_sent = apply_augment(my_eda1, sent, subj, obj)
        augmented.append(new_sent)

    return pd.DataFrame({
        'id': df['id'],
        'sentence': augmented,
        'subject_entity': subjects,
        'object_entity': objects,
        'label': df['label'],
    })
def main():
    """Augment the module-level training data and save it as a CSV.

    The result is written to "../dataset/train/" followed by FILE_NAME,
    without the DataFrame index column.
    """
    output_path = "../dataset/train/" + FILE_NAME
    augmented = augment_sent(data, eda1, eda2, aeda)
    augmented.to_csv(output_path, index=False)


# Only run the augmentation when this file is executed as a script.
if __name__ == "__main__":
    main()