forked from ronuchit/GLIB-AAAI-2021
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
65 lines (60 loc) · 2.29 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
sns.set(style="darkgrid")
def smooth_curve(x, y):
halfwidth = int(np.ceil(len(x) / 50)) # Halfwidth of our smoothing convolution
k = halfwidth
xsmoo = x
ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1),
mode='same')
return xsmoo, ysmoo
def plot_results(domain_name, learning_name, all_results, outdir="results",
smooth=False, dist=False):
"""Results are lists of single-run result lists, across different
random seeds.
"""
outdir = os.path.join(os.path.dirname(os.path.realpath(__file__)), outdir)
outfile = os.path.join(outdir, "{}_{}_{}.png".format(
domain_name, learning_name, "dist" if dist else "succ"))
plt.figure()
if dist:
ylabel = "Test Set Average Variational Distance"
else:
ylabel = "Test Set Success Rate"
plt.ylabel(ylabel)
for curiosity_module in sorted(all_results):
results = np.array(all_results[curiosity_module])
if len(results) == 0:
continue
label = curiosity_module
xs = results[0, :, 0]
if dist:
ys = results[:, :, 2]
else:
ys = results[:, :, 1]
results_mean = np.mean(ys, axis=0)
# results_std = np.std(ys, axis=0)
if smooth:
xs, results_mean = smooth_curve(xs, results_mean)
# _, results_std = smooth_curve(xs, results_std)
plt.plot(xs, results_mean, label=label.replace("_", " "))
# plt.fill_between(xs, results_mean+results_std,
# results_mean-results_std, alpha=0.2)
min_seeds = min(len(x) for x in all_results.values())
max_seeds = max(len(x) for x in all_results.values())
if min_seeds == max_seeds:
title = "{} Domain, {} Learner ({} seeds)".format(
domain_name, learning_name, min_seeds)
else:
title = "{} Domain, {} Learner ({} to {} seeds)".format(
domain_name, learning_name, min_seeds, max_seeds)
if smooth:
title += " [smoothed]"
plt.title(title)
plt.ylim((-0.1, 1.1))
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig(outfile, dpi=300)
print("Wrote out to {}".format(outfile))