forked from xuchenglin28/speaker_extraction_SpEx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_sisdr.py
70 lines (60 loc) · 1.91 KB
/
compute_sisdr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
"""
Compute SI-SDR as the evaluation metric
"""
import argparse
from tqdm import tqdm
from collections import defaultdict
from utils.sisdr import sisdr
from utils.audio import WaveReader, Reader
class Report(object):
    """Accumulate SI-SDR scores per group and print each group's average.

    When a ``spk2gender`` path is supplied, it is loaded through ``Reader``
    and every score is filed under ``self.s2g[key]``; otherwise all scores
    share the single label ``"NG"``.
    """

    def __init__(self, spk2gender=None):
        """Set up empty accumulators and, optionally, the key-to-group map.

        Args:
            spk2gender: Value handed to ``Reader`` to build the lookup.
                Any falsy value (``None``, ``""``) disables the lookup.
        """
        if spk2gender:
            self.s2g = Reader(spk2gender)
        else:
            self.s2g = None
        # Running sum of scores and number of scores, both keyed by group.
        self.snr = defaultdict(float)
        self.cnt = defaultdict(int)

    def add(self, key, val):
        """Record one score.

        Args:
            key: Identifier used to look up the group in the mapping.
            val: Score to add to that group's running total.
        """
        group = self.s2g[key] if self.s2g else "NG"
        self.snr[group] += val
        self.cnt[group] += 1

    def report(self):
        """Print a header, then one ``group: count/mean`` line per group."""
        print("SI-SDR(dB) Report: ")
        for group, total in self.snr.items():
            count = self.cnt[group]
            print("{}: {:d}/{:.3f}".format(group, count, total / count))
def run(args):
    """Score every separated utterance against its reference and print a report.

    Args:
        args: Parsed command-line namespace; reads ``sep_scp``, ``ref_scp``
            and ``spk2gender``.
    """
    report = Report(args.spk2gender)
    separated = WaveReader(args.sep_scp)
    references = WaveReader(args.ref_scp)
    for utt_id, estimate in tqdm(separated):
        target = references[utt_id]
        n_est, n_ref = estimate.size, target.size
        if n_est != n_ref:
            # Lengths differ: score only the leading samples both signals
            # have (assumes 1-D waveforms -- TODO confirm against WaveReader).
            shortest = min(n_est, n_ref)
            estimate, target = estimate[:shortest], target[:shortest]
        report.add(utt_id, sisdr(estimate, target))
    report.report()
if __name__ == "__main__":
    # Command-line entry point: parse the script lists and hand off to run().
    parser = argparse.ArgumentParser(
        description=
        "Command to compute SI-SDR, as metric of the separation quality",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Both script lists are mandatory: run() passes them straight to
    # WaveReader, so let argparse reject a missing one with a usage error
    # instead of letting None propagate into the reader.
    parser.add_argument(
        "--sep_scp",
        type=str,
        required=True,
        help="Separated speech list, egs: spk1.scp")
    parser.add_argument(
        "--ref_scp",
        type=str,
        required=True,
        help="Reference speech list, egs: ref.scp")
    # Optional: the empty default is falsy, which makes Report skip the
    # per-gender lookup.
    parser.add_argument(
        "--spk2gender",
        type=str,
        default="",
        help="If assigned, report results per gender")
    args = parser.parse_args()
    run(args)