-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathevaluate.py
116 lines (100 loc) · 4.55 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# coding: utf-8
from __future__ import print_function
import os
import time
import random
from PIL import Image
#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
from utils import *
from model import *
from glob import glob
from skimage import color,filters
# Single TF1-style session used for graph construction, checkpoint restore
# and inference in this script.
sess = tf.Session()


def _image_placeholder(name, channels):
    """Return a float32 placeholder of shape [None, None, None, channels]."""
    return tf.placeholder(
        dtype=tf.float32, shape=[None, None, None, channels], name=name)


def _trainable_vars_named(tag):
    """Return the trainable variables whose name contains ``tag``."""
    return [v for v in tf.trainable_variables() if tag in v.name]


# Graph inputs: only the channel dimension is fixed, the other three are free.
input_decom = _image_placeholder('input_decom', 3)
input_low_r = _image_placeholder('input_low_r', 3)
input_low_i = _image_placeholder('input_low_i', 1)
input_high_r = _image_placeholder('input_high_r', 3)
input_high_i = _image_placeholder('input_high_i', 1)
input_low_i_ratio = _image_placeholder('input_low_i_ratio', 1)

# The networks are built in the same order as before: decomposition first,
# then restoration, then illumination adjustment.
# NOTE(review): R/I presumably stand for reflectance/illumination - confirm
# against model.py.
decom_output_R, decom_output_I = DecomNet_simple(input_decom)
R_decom, I_decom = decom_output_R, decom_output_I
output_r = Restoration_net(input_low_r, input_low_i)
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)

# Each network gets its own Saver so its weights can be restored from a
# separate checkpoint directory.
var_Decom = _trainable_vars_named('DecomNet')
var_adjust = _trainable_vars_named('Illumination_adjust_net')
var_restoration = _trainable_vars_named('Restoration_net')

saver_Decom = tf.train.Saver(var_list=var_Decom)
saver_adjust = tf.train.Saver(var_list=var_adjust)
saver_restoration = tf.train.Saver(var_list=var_restoration)
def _restore_checkpoint(saver, checkpoint_dir, missing_message):
    """Restore the checkpoint recorded in ``checkpoint_dir`` into ``sess``.

    Args:
        saver: ``tf.train.Saver`` covering the variables to restore.
        checkpoint_dir: Directory handed to ``tf.train.get_checkpoint_state``.
        missing_message: Text printed when no usable checkpoint is found.

    Returns:
        Whatever ``tf.train.get_checkpoint_state`` returned for the directory.
    """
    state = tf.train.get_checkpoint_state(checkpoint_dir)
    # A state without a model path cannot be restored, so it is reported as
    # missing instead of being passed on to saver.restore().
    if state and state.model_checkpoint_path:
        print('loaded ' + state.model_checkpoint_path)
        saver.restore(sess, state.model_checkpoint_path)
    else:
        # The script keeps going after a missing checkpoint, as it always did.
        # NOTE(review): no variable initializer is run anywhere in this file,
        # so later sess.run() calls presumably fail in that case - confirm.
        print(missing_message)
    return state


decom_checkpoint_dir = './checkpoint/decom_net_train/'
ckpt_pre = _restore_checkpoint(
    saver_Decom, decom_checkpoint_dir, 'No decomnet checkpoint!')

checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/'
ckpt_adjust = _restore_checkpoint(
    saver_adjust, checkpoint_dir_adjust, "No adjust pre model!")

checkpoint_dir_restoration = './checkpoint/Restoration_net_train/'
ckpt = _restore_checkpoint(
    saver_restoration, checkpoint_dir_restoration, "No restoration pre model!")
### load eval data
# Every regular file directly under ./test is loaded, in sorted path order.
# Directories matched by the glob are skipped instead of being handed to
# load_images().
eval_low_data = []
eval_img_name = []
eval_low_data_name = sorted(
    path for path in glob('./test/*') if os.path.isfile(path))
for image_path in eval_low_data_name:
    # splitext() removes only the final extension: "a.1.png" stays "a.1"
    # (so it no longer collides with "a.2.png") and a file without any
    # extension keeps its full name.
    name = os.path.splitext(os.path.basename(image_path))[0]
    eval_img_name.append(name)
    eval_low_im = load_images(image_path)
    eval_low_data.append(eval_low_im)
    print(eval_low_im.shape)
def _enhance_image(input_low, ratio=5.0):
    """Run one image through the three networks and return the fused result.

    Args:
        input_low: H x W x 3 array (it is fed, with a batch axis added, to
            the 3-channel ``input_decom`` placeholder).
        ratio: Exposure level fed to the illumination-adjustment network as
            a constant map. Change it to get a different exposure level; the
            value can be 0-5.0.

    Returns:
        The fused image with a leading batch axis of size 1.
        NOTE(review): presumably shaped (1, H, W, 3) - this relies on the
        networks preserving the spatial size; confirm against model.py.
    """
    input_low_eval = np.expand_dims(input_low, axis=0)
    h, w, _ = input_low.shape

    decom_r_low, decom_i_low = sess.run(
        [decom_output_R, decom_output_I],
        feed_dict={input_decom: input_low_eval})
    restoration_r = sess.run(
        output_r,
        feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})

    # Constant (1, h, w, 1) map that carries the requested exposure ratio.
    i_low_ratio = np.ones((1, h, w, 1), dtype=np.float32) * ratio
    adjust_i = sess.run(
        output_i,
        feed_dict={input_low_i: decom_i_low, input_low_i_ratio: i_low_ratio})

    # The restoration result recovers more detail from very dark regions,
    # but it renders those regions in gray. The mask below, derived from a
    # blurred grayscale version of the decomposed R output, alleviates that
    # weakness.
    # Only the batch axis is squeezed: a bare np.squeeze() would also drop a
    # height or width of 1 and break the broadcasting further down.
    decom_r_sq = np.squeeze(decom_r_low, axis=0)
    r_gray = color.rgb2gray(decom_r_sq)
    r_gray_gaussion = filters.gaussian(r_gray, 3)
    low_i = np.minimum((r_gray_gaussion * 2) ** 0.5, 1)
    low_i_expand = low_i[np.newaxis, :, :, np.newaxis]
    result_denoise = restoration_r * low_i_expand
    fusion4 = result_denoise * adjust_i

    # Fuse with the original input to avoid over-exposure.
    return decom_i_low * input_low_eval + (1 - decom_i_low) * fusion4


sample_dir = './results/test/'
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

print("Start evalating!")
start_time = time.time()
for idx, (name, input_low) in enumerate(zip(eval_img_name, eval_low_data)):
    print(idx)
    fusion2 = _enhance_image(input_low, ratio=5.0)
    save_images(os.path.join(sample_dir, '%s_kindle.png' % name), fusion2)
# The timer was started but never reported before; print the total here.
print('Finished %d image(s) in %.2f s'
      % (len(eval_low_data), time.time() - start_time))