latvian_tweet_sentiment_corpus_run_gpt.py
"""Label the Latvian Tweet Sentiment Corpus with GPT prompts and export per-prompt metrics to CSV."""
import csv
import json
import os.path

from loguru import logger
from tqdm import tqdm

from GPT_prompter import Prompter

# Resume from a previously labeled dataset if it exists; otherwise start from the raw corpus.
if os.path.exists('data/ltsc_labeled.json'):
    with open('data/ltsc_labeled.json', 'r', encoding='utf8') as f:
        dataset = json.load(f)
else:
    with open('latvian-tweet-sentiment-corpus/tweet_corpus.json', 'r', encoding='utf8') as f:
        dataset = json.load(f)

    # Group tweets by their ground-truth sentiment label.
    dataset_by_ground_truth = {}
    for d in dataset:
        if d['sentiment'] not in dataset_by_ground_truth:
            dataset_by_ground_truth[d['sentiment']] = []
        dataset_by_ground_truth[d['sentiment']].append(d)

    # Balance the dataset by truncating every sentiment group to the size of the smallest one.
    min_sentiment_group = min([len(dataset_by_ground_truth[k]) for k in dataset_by_ground_truth])
    for k in dataset_by_ground_truth:
        dataset_by_ground_truth[k] = dataset_by_ground_truth[k][:min_sentiment_group]

    dataset = []
    for k in dataset_by_ground_truth:
        dataset += dataset_by_ground_truth[k]

# Label each tweet with GPT; checkpoint progress every 10 samples and on failure.
for i, d in enumerate(tqdm(dataset)):
    try:
        Prompter.label_sample(d)
    except Exception as e:
        logger.exception(e)
        logger.error(f'Failed to get response for {d["text"]}')
        with open('data/ltsc_labeled.json', 'w', encoding='utf8') as f:
            json.dump(dataset, f, indent=4, ensure_ascii=False)
        exit(0)
    if i % 10 == 0:
        with open('data/ltsc_labeled.json', 'w', encoding='utf8') as f:
            json.dump(dataset, f, indent=4, ensure_ascii=False)

# Compute per-prompt metrics and write them out as a CSV report.
prompt_metrics = Prompter.get_metrics(dataset)

# with open('data/ltsc_labeled_metrics.json', 'w') as f:
#     json.dump(prompt_metrics, f, indent=4, ensure_ascii=False)

with open('data/ltsc_labeled_metrics.csv', 'w', encoding='utf8', newline='') as f:
    writer = csv.writer(f)
    writer.writerow([
        'prompt',
        'prompt_idx',
        'prompt_repeat_idx',
        'accuracy',
        'precision_avg',
        'recall_avg',
        'f1',
        'parsable_response_percent',
        'parsable_response_count',
        'precision_negative',
        'precision_neutral',
        'precision_positive',
        'recall_negative',
        'recall_neutral',
        'recall_positive',
        'prompt_text'
    ])
    for prompt in prompt_metrics:
        writer.writerow([
            prompt,
            prompt_metrics[prompt]['prompt_idx'],
            prompt_metrics[prompt]['prompt_repeat_idx'],
            prompt_metrics[prompt]['accuracy'],
            prompt_metrics[prompt]['precision_avg'],
            prompt_metrics[prompt]['recall_avg'],
            prompt_metrics[prompt]['f1'],
            prompt_metrics[prompt]['percent_parsable_responses'],
            prompt_metrics[prompt]['count_parsable_responses'],
            prompt_metrics[prompt]['precision']['negative'],
            prompt_metrics[prompt]['precision']['neutral'],
            prompt_metrics[prompt]['precision']['positive'],
            prompt_metrics[prompt]['recall']['negative'],
            prompt_metrics[prompt]['recall']['neutral'],
            prompt_metrics[prompt]['recall']['positive'],
            prompt_metrics[prompt]['prompt_text']
        ])