# evaluate_LOLdataset.py — KinD evaluation script for the LOL eval15 dataset.
# (Web-page navigation and gutter line-number residue from the scrape removed.)
# coding: utf-8
from __future__ import print_function
import os
import time
import random
from PIL import Image
#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
from utils import *
from model import *
from glob import glob
# ---- Build the KinD inference graph: Decom / Restoration / Illumination-adjust ----
sess = tf.Session()

def _image_placeholder(channels, tag):
    # NHWC placeholder with free batch/height/width so any image size is accepted.
    return tf.placeholder(tf.float32, [None, None, None, channels], name=tag)

input_decom = _image_placeholder(3, 'input_decom')
input_low_r = _image_placeholder(3, 'input_low_r')
input_low_i = _image_placeholder(1, 'input_low_i')
input_high_r = _image_placeholder(3, 'input_high_r')
input_high_i = _image_placeholder(1, 'input_high_i')
input_low_i_ratio = _image_placeholder(1, 'input_low_i_ratio')

# Decomposition net splits an image into reflectance R and illumination I.
R_decom, I_decom = DecomNet_simple(input_decom)
decom_output_R = R_decom
decom_output_I = I_decom
# Restoration net cleans up the reflectance; adjust net rescales the illumination.
output_r = Restoration_net(input_low_r, input_low_i)
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)

def _vars_of(scope_tag):
    # Collect the trainable variables belonging to one sub-network by name scope.
    return [v for v in tf.trainable_variables() if scope_tag in v.name]

var_Decom = _vars_of('DecomNet')
var_adjust = _vars_of('Illumination_adjust_net')
var_restoration = _vars_of('Restoration_net')

# One saver per sub-network so each checkpoint can be restored independently.
saver_Decom = tf.train.Saver(var_list=var_Decom)
saver_adjust = tf.train.Saver(var_list=var_adjust)
saver_restoration = tf.train.Saver(var_list=var_restoration)
def _restore_or_warn(saver, checkpoint_dir, missing_msg):
    """Restore `saver` from the latest checkpoint in `checkpoint_dir`.

    Prints the path of the checkpoint that was loaded, or `missing_msg`
    when the directory contains no checkpoint state. Uses the module-level
    `sess` for the restore.
    """
    ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt_state:
        print('loaded ' + ckpt_state.model_checkpoint_path)
        saver.restore(sess, ckpt_state.model_checkpoint_path)
    else:
        print(missing_msg)

# Each sub-network was trained separately and has its own checkpoint directory.
_restore_or_warn(saver_Decom, './checkpoint/decom_net_train/', 'No decomnet checkpoint!')
_restore_or_warn(saver_adjust, './checkpoint/illumination_adjust_net_train/', 'No adjust pre model!')
_restore_or_warn(saver_restoration, './checkpoint/Restoration_net_train/', 'No restoration pre model!')
### Load the evaluation data (LOL eval15 split).
# Low-light inputs plus their base names (used later to name the result files).
eval_low_data = []
eval_img_name = []
eval_low_data_name = glob('./test/eval15/low/*.png')
eval_low_data_name.sort()
for low_path in eval_low_data_name:
    _, name = os.path.split(low_path)
    # Keep only the text before the first dot as the output name
    # (the unused `suffix` extraction of the original has been dropped).
    eval_img_name.append(name[:name.find('.')])
    eval_low_im = load_images(low_path)
    eval_low_data.append(eval_low_im)
    print(eval_low_im.shape)

# To get better results, the illumination adjustment ratio is computed from
# decom_i_high, so the paired normal-light (high) images are needed as well.
# NOTE(review): pairing relies on the sorted low/high listings lining up —
# assumes matching file names in both directories; confirm against the dataset.
eval_high_data = []
eval_high_data_name = glob('./test/eval15/high/*.png')
eval_high_data_name.sort()
for high_path in eval_high_data_name:
    eval_high_data.append(load_images(high_path))
# Run the three-stage KinD pipeline on every eval image and save the results.
sample_dir = './results/LOLdataset_eval15/'
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

print("Start evaluating!")  # fixed typo: was "Start evalating!"
start_time = time.time()
for idx, (name, input_low, input_high) in enumerate(
        zip(eval_img_name, eval_low_data, eval_high_data)):
    print(idx)
    # Add a batch dimension: the graph's placeholders are NHWC.
    input_low_eval = np.expand_dims(input_low, axis=0)
    input_high_eval = np.expand_dims(input_high, axis=0)
    h, w, _ = input_low.shape
    # Decompose both the low- and high-light image into reflectance/illumination.
    decom_r_low, decom_i_low = sess.run(
        [decom_output_R, decom_output_I], feed_dict={input_decom: input_low_eval})
    decom_r_high, decom_i_high = sess.run(
        [decom_output_R, decom_output_I], feed_dict={input_decom: input_high_eval})
    # Denoise the low-light reflectance.
    restoration_r = sess.run(
        output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})
    # Target brightening ratio from the paired high-light illumination;
    # the small epsilon guards against division by zero.
    ratio = np.mean(decom_i_high / (decom_i_low + 0.0001))
    # Constant ratio map with the batch/channel dims already in place
    # (replaces ones*ratio + two expand_dims).
    i_low_ratio_expand2 = np.full([1, h, w, 1], ratio, dtype=np.float32)
    adjust_i = sess.run(output_i, feed_dict={input_low_i: decom_i_low,
                                             input_low_i_ratio: i_low_ratio_expand2})
    # Recombine: enhanced image = restored reflectance * adjusted illumination.
    fusion = restoration_r * adjust_i
    save_images(os.path.join(sample_dir, '%s_kindle.png' % name), fusion)
# start_time was captured but never reported in the original — dead code fixed.
print("Done, total time: %.4f s" % (time.time() - start_time))