-
Notifications
You must be signed in to change notification settings - Fork 11
/
apply.py
47 lines (33 loc) · 1.21 KB
/
apply.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import get_data
import numpy as np
import torchaudio
def number_of_correct(pred, target):
    """Count how many predicted class indices match the targets.

    ``pred`` has its size-1 dimensions squeezed out before an element-wise
    equality test against ``target``; the matches are summed.

    Returns:
        The number of matching elements as a Python number.
    """
    hits = torch.eq(torch.squeeze(pred), target)
    return hits.sum().item()
def get_likely_index(tensor):
    """Return the index of the largest value along the last dimension of ``tensor``."""
    return torch.argmax(tensor, dim=-1)
def compute_accuracy(model, data_loader, device):
    """Return the top-1 accuracy of ``model`` over the batches in ``data_loader``.

    The model is put in eval mode, each ``(data, target)`` batch is moved to
    ``device``, and the prediction is the argmax of the model output along the
    last dimension (via ``get_likely_index``), compared with ``target`` via
    ``number_of_correct``.

    Args:
        model: classifier with an ``eval()`` method, called on each ``data`` batch.
        data_loader: iterable of ``(data, target)`` tensor pairs.
        device: device the batches are moved to before the forward pass.

    Returns:
        float: correct predictions divided by the number of target elements
        that were actually evaluated.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no targets.
    """
    model.eval()
    correct = 0
    seen = 0
    # Inference only: with autograd disabled, no graph or saved activations are
    # retained per batch.
    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            pred = get_likely_index(model(data))
            correct += number_of_correct(pred, target)
            # Count what was evaluated rather than len(dataset): a subset
            # sampler, drop_last, or a replacement sampler would otherwise skew
            # the ratio.
            seen += target.numel()
    return correct / seen
def apply_to_wav(model, waveform: torch.Tensor, sample_rate: float, device: str):
    """Classify one waveform and return every label ranked by probability.

    The waveform is converted with ``get_data.prepare_wav``, given a leading
    batch dimension of size 1, moved to ``device`` and passed through ``model``
    (in eval mode, without autograd). A softmax over the last dimension turns
    the output into class probabilities.

    Args:
        model: trained classifier, called on a batch containing one input.
        waveform: audio tensor, e.g. as returned by ``torchaudio.load``.
        sample_rate: sampling rate of ``waveform``, forwarded to ``prepare_wav``.
        device: device the prepared input is moved to.

    Returns:
        list of ``(label, probability)`` pairs covering every class, sorted by
        descending probability; labels come from ``get_data.idx_to_label``.
    """
    model.eval()
    mel_spec = get_data.prepare_wav(waveform, sample_rate)
    mel_spec = torch.unsqueeze(mel_spec, dim=0).to(device)
    with torch.no_grad():
        res = model(mel_spec)
    # The input has a batch size of 1, so any leading batch axes left in the
    # output (e.g. (1, C) or (1, 1, C) — depends on the model, TODO confirm) are
    # flattened away. argsort and indexing then run over classes, not over the
    # batch axis. This is a no-op if the model already returns a 1-D vector.
    probs = torch.nn.Softmax(dim=-1)(res).cpu().detach().numpy().reshape(-1)
    return [(get_data.idx_to_label(idx), probs[idx]) for idx in np.argsort(-probs)]
def apply_to_file(model, wav_file: str, device: str):
    """Decode ``wav_file`` with ``torchaudio.load`` and classify it via ``apply_to_wav``.

    Returns exactly what ``apply_to_wav`` returns for the decoded waveform and
    its sample rate.
    """
    audio, rate = torchaudio.load(wav_file)
    return apply_to_wav(model=model, waveform=audio, sample_rate=rate, device=device)