-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcalculate_psnr.py
90 lines (66 loc) · 2.22 KB
/
calculate_psnr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import torch
from tqdm import tqdm
import math
def img_psnr(img1, img2, data_range=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: First image as a NumPy array.
        img2: Second image, expected to have the same shape as ``img1``.
        data_range: Peak signal value, i.e. the largest possible pixel
            value. Defaults to 1.0 for images scaled to [0, 1]; pass 255
            for 8-bit images.

    Returns:
        The PSNR as a float. When the mean squared error is below 1e-10
        the images are treated as identical and 100.0 is returned, which
        avoids dividing by zero / taking the logarithm of infinity.
    """
    # Dividing by 1.0 promotes integer images to floating point so the
    # subtraction cannot wrap around (e.g. uint8 0 - 1 == 255).
    mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)

    if mse < 1e-10:
        # Same cap as before, but as a float so both paths return one type.
        return 100.0

    return 20 * math.log10(data_range / math.sqrt(mse))
def trans(x):
    """Identity pre-processing hook applied to each video batch.

    Returns the very object that was passed in; no copy or conversion
    is performed.
    """
    unchanged = x
    return unchanged
def _to_numpy(videos):
    """Return *videos* as a NumPy array, converting torch tensors safely."""
    if torch.is_tensor(videos):
        # Tensor.numpy() raises for tensors that require grad or live on
        # an accelerator, so detach and move to the CPU first.
        return videos.detach().cpu().numpy()
    return np.asarray(videos)


def calculate_psnr(videos1, videos2, only_final=False):
    """Compute PSNR statistics between two batches of videos.

    Every pair of corresponding frames is scored with ``img_psnr`` (which
    assumes pixel values in [0, 1]); the scores are then aggregated.

    Args:
        videos1: Batch of videos laid out as
            [batch_size, timestamps, channel, h, w]. May be a torch tensor
            (on any device, with or without grad) or a NumPy array.
        videos2: Batch with exactly the same shape as ``videos1``.
        only_final: If True, report a single mean/std over all frames of
            all videos. If False (default), report one mean/std per
            timestamp, each taken across the batch.

    Returns:
        A dict with two lists of equal length:
        ``"value"`` holds the mean PSNR(s) and ``"value_std"`` the
        standard deviation(s). The lists have one entry when
        ``only_final`` is True, otherwise one entry per timestamp.
        An empty batch yields NaN statistics (NumPy emits a
        RuntimeWarning) instead of failing.

    Raises:
        AssertionError: If the two batches differ in shape.
    """
    print("calculate_psnr...")

    # videos [batch_size, timestamps, channel, h, w]
    assert videos1.shape == videos2.shape, (
        "videos1 and videos2 must have the same shape"
    )

    videos1 = _to_numpy(trans(videos1))
    videos2 = _to_numpy(trans(videos2))

    # Take the sizes from the batch itself rather than from a variable
    # left over by the loop below, which is undefined for an empty batch.
    num_videos, num_frames = videos1.shape[:2]

    psnr_results = []
    for video_num in tqdm(range(num_videos)):
        # video [timestamps, channel, h, w]
        video1 = videos1[video_num]
        video2 = videos2[video_num]

        # One PSNR per frame; each img is [channel, h, w].
        psnr_results.append(
            [img_psnr(img1, img2) for img1, img2 in zip(video1, video2)]
        )

    # Force a 2-D [batch_size, timestamps] layout so the column indexing
    # below is well-defined even when the batch is empty.
    psnr_results = np.array(psnr_results, dtype=np.float64).reshape(
        num_videos, num_frames
    )

    if only_final:
        psnr = [np.mean(psnr_results)]
        psnr_std = [np.std(psnr_results)]
    else:
        psnr = [np.mean(psnr_results[:, t]) for t in range(num_frames)]
        psnr_std = [np.std(psnr_results[:, t]) for t in range(num_frames)]

    result = {
        "value": psnr,
        "value_std": psnr_std,
    }

    return result
# test code / usage example
def main():
    """Run a small usage example and print the per-timestamp PSNR stats.

    Compares an all-zeros batch against an all-ones batch of the same
    shape and prints the resulting means and standard deviations.
    """
    # (number of videos, video length, channels, height, width)
    batch_shape = (8, 30, 3, 64, 64)

    videos1 = torch.zeros(*batch_shape, requires_grad=False)
    videos2 = torch.ones(*batch_shape, requires_grad=False)

    result = calculate_psnr(videos1, videos2)

    for label, key in (("[psnr avg]", "value"), ("[psnr std]", "value_std")):
        print(label, result[key])
# Run the usage example only when executed as a script, not on import.
if __name__ == "__main__":
    main()