Add TensorFlow 2 support #55

Status: Open. Wants to merge 6 commits into base: master.
2 changes: 1 addition & 1 deletion README.md
100644 → 100755
@@ -22,7 +22,7 @@
# Installation
Requirements: `python3` and `tensorflow`. Tested on Ubuntu 16.04 and Arch Linux. OS X may also work, though not tested.
```
sudo pip3 install tensorflow-gpu opencv-python tifffile scikit-image
sudo pip3 install tensorflow-gpu opencv-python tifffile scikit-image tensorflow_probability tf_slim
git clone https://github.com/yuanming-hu/exposure --recursive
cd exposure
```
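The two new pip packages, `tensorflow_probability` and `tf_slim`, cover functionality that previously lived under `tf.contrib`. A quick sanity check after installation (not part of this PR; the versions printed depend on what pip resolves):

```python
# Minimal import check for the TF2-era dependencies added above.
import tensorflow as tf
import tf_slim as slim
import tensorflow_probability as tfp

print('tensorflow:', tf.__version__)  # expect a 2.x version
print('tf_slim fully_connected available:', hasattr(slim, 'fully_connected'))
print('tensorflow_probability:', tfp.__version__)
```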
28 changes: 14 additions & 14 deletions agent.py
100755 → 100644
@@ -1,5 +1,5 @@
import tensorflow as tf
import tensorflow.contrib.layers as ly
import tf_slim as ly
from util import lrelu
import cv2
import math
@@ -33,7 +33,7 @@ def feature_extractor(net, output_dim, cfg):
print(' ', str(net.get_shape()))
print('before fc: ', net.get_shape()[1])
net = tf.reshape(net, [-1, output_dim])
net = tf.nn.dropout(net, cfg.dropout_keep_prob)
net = tf.nn.dropout(net, rate=1 - (cfg.dropout_keep_prob))
return net


@@ -56,7 +56,7 @@ def agent_generator(inp, is_train, progress, cfg, high_res=None, alex_in=None):
cfg=cfg)
# filter_features = ly.dropout(filter_features)
for j, filter in enumerate(filters):
with tf.variable_scope('filter_%d' % j):
with tf.compat.v1.variable_scope('filter_%d' % j):
print(' creating filter:', j, 'name:', str(filter.__class__), 'abbr.',
filter.get_short_name())
if not cfg.shared_feature_extractor:
@@ -77,7 +77,7 @@ def agent_generator(inp, is_train, progress, cfg, high_res=None, alex_in=None):
filtered_images = tf.stack(values=filtered_images, axis=1)
print(' filtered_images:', filtered_images.shape)

with tf.variable_scope('action_selection'):
with tf.compat.v1.variable_scope('action_selection'):
selector_features = feature_extractor(
net=enrich_image_input(cfg, net, states),
output_dim=cfg.feature_extractor_dims,
@@ -104,29 +104,29 @@ def agent_generator(inp, is_train, progress, cfg, high_res=None, alex_in=None):
pdf = pdf * (1 - cfg.exploration) + cfg.exploration * 1.0 / len(filters)
# pdf = tf.to_float(is_train) * tf.concat([pdf[:, :1], pdf[:, 1:] * states[:, STATE_DROPOUT_BEGIN:]], axis=1) \
# + (1.0 - tf.to_float(is_train)) * pdf
pdf = pdf / (tf.reduce_sum(pdf, axis=1, keep_dims=True) + 1e-30)
entropy = -pdf * tf.log(pdf)
entropy = tf.reduce_sum(entropy, axis=1)[:, None]
pdf = pdf / (tf.reduce_sum(input_tensor=pdf, axis=1, keepdims=True) + 1e-30)
entropy = -pdf * tf.math.log(pdf)
entropy = tf.reduce_sum(input_tensor=entropy, axis=1)[:, None]
print(' pdf:', pdf.shape)
print(' entropy:', entropy.shape)
print(' selection_noise:', selection_noise.shape)
random_filter_id = pdf_sample(pdf, selection_noise)
max_filter_id = tf.cast(tf.argmax(pdf, axis=1), tf.int32)
max_filter_id = tf.cast(tf.argmax(input=pdf, axis=1), tf.int32)
selected_filter_id = is_train * random_filter_id + (
1 - is_train) * max_filter_id
print(' selected_filter_id:', selected_filter_id.shape)
filter_one_hot = tf.one_hot(
selected_filter_id, depth=len(filters), dtype=tf.float32)
print(' filter one_hot', filter_one_hot.shape)
surrogate = tf.reduce_sum(
filter_one_hot * tf.log(pdf + 1e-10), axis=1, keep_dims=True)
input_tensor=filter_one_hot * tf.math.log(pdf + 1e-10), axis=1, keepdims=True)

net = tf.reduce_sum(
filtered_images * filter_one_hot[:, :, None, None, None], axis=1)
input_tensor=filtered_images * filter_one_hot[:, :, None, None, None], axis=1)
if high_res is not None:
high_res_outputs = tf.stack(values=high_res_outputs, axis=1)
high_res_output = tf.reduce_sum(
high_res_outputs * filter_one_hot[:, :, None, None, None], axis=1)
input_tensor=high_res_outputs * filter_one_hot[:, :, None, None, None], axis=1)

# only the first image will get debug_info
debug_info = {
@@ -228,9 +228,9 @@ def debugger(debug_info, combined=True):
early_stop_penalty = (1 - is_last_step) * submitted * cfg.early_stop_penalty

usage_penalty = tf.reduce_sum(
filter_usage * filter_one_hot[:, regular_filter_start:],
input_tensor=filter_usage * filter_one_hot[:, regular_filter_start:],
axis=1,
keep_dims=True)
keepdims=True)
new_filter_usage = tf.maximum(filter_usage,
filter_one_hot[:, regular_filter_start:])
new_states[STATE_STEP_DIM + 1] = new_filter_usage
@@ -247,7 +247,7 @@ def debugger(debug_info, combined=True):

# Will be substracted from award
penalty = tf.reduce_mean(
tf.maximum(net - 1, 0)**2, axis=(1, 2, 3)
input_tensor=tf.maximum(net - 1, 0)**2, axis=(1, 2, 3)
)[:,
None] + entropy_penalty + usage_penalty * cfg.filter_usage_penalty + early_stop_penalty
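The agent.py edits are the standard TF1 to TF2 renames: `tf.variable_scope` to `tf.compat.v1.variable_scope`, `tf.log` to `tf.math.log`, `keep_dims` to `keepdims`, the explicit `input_tensor=` keyword on reductions, and `tf.nn.dropout` switching from a keep probability to a drop `rate`. A minimal sketch of the dropout conversion, using a made-up tensor and keep probability rather than anything from this repo:

```python
import tensorflow as tf

x = tf.ones([4, 8])
keep_prob = 0.5  # TF1-style value (cfg.dropout_keep_prob in this codebase)

# TF1: tf.nn.dropout(x, keep_prob)
# TF2: pass the probability of *dropping* a unit instead.
y = tf.nn.dropout(x, rate=1 - keep_prob)
print(y.shape)  # (4, 8); roughly half the entries are zeroed
```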

2 changes: 1 addition & 1 deletion config_example.py
@@ -155,7 +155,7 @@ def c_lr_callback(t):
cfg.lr_g = g_lr_callback
cfg.lr_c = c_lr_callback

optimizer = lambda lr: tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.9)
optimizer = lambda lr: tf.compat.v1.train.AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.9)

cfg.generator_optimizer = optimizer
cfg.critic_optimizer = optimizer
2 changes: 1 addition & 1 deletion config_sintel.py
@@ -156,7 +156,7 @@ def c_lr_callback(t):
cfg.lr_g = g_lr_callback
cfg.lr_c = c_lr_callback

optimizer = lambda lr: tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.9)
optimizer = lambda lr: tf.compat.v1.train.AdamOptimizer(learning_rate=lr, beta1=0.5, beta2=0.9)

cfg.generator_optimizer = optimizer
cfg.critic_optimizer = optimizer
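Both config files make the same single change: the optimizer constructor moves under the compat namespace, with the TF1 keyword names `beta1`/`beta2` unchanged. Since the rest of the codebase still builds a v1-style graph, here is a minimal sketch of how the lambda would be consumed; the `disable_v2_behavior()` call is an assumption about the training setup, not something shown in this diff:

```python
import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # assumed: graph-mode training, as in the original TF1 code

optimizer = lambda lr: tf.compat.v1.train.AdamOptimizer(
    learning_rate=lr, beta1=0.5, beta2=0.9)

opt = optimizer(1e-4)
# opt.minimize(loss) is then called inside the existing graph-building code.
```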
16 changes: 8 additions & 8 deletions critics.py
@@ -1,5 +1,5 @@
import tensorflow as tf
import tensorflow.contrib.layers as ly
import tf_slim as ly
from util import lrelu


@@ -40,26 +40,26 @@ def cnn(net, is_train, cfg):

# Input: float \in [0, 1]
def critic(images, cfg, states=None, is_train=None, reuse=False):
with tf.variable_scope('critic') as scope:
with tf.compat.v1.variable_scope('critic') as scope:
if reuse:
scope.reuse_variables()

if True:
lum = (images[:, :, :, 0] * 0.27 + images[:, :, :, 1] * 0.67 +
images[:, :, :, 2] * 0.06 + 1e-5)[:, :, :]
# luminance and contrast
luminance, contrast = tf.nn.moments(lum, axes=[1, 2])
luminance, contrast = tf.nn.moments(x=lum, axes=[1, 2])

# saturation
i_max = tf.reduce_max(
tf.clip_by_value(images, clip_value_min=0.0, clip_value_max=1.0),
reduction_indices=[3])
input_tensor=tf.clip_by_value(images, clip_value_min=0.0, clip_value_max=1.0),
axis=[3])
i_min = tf.reduce_min(
tf.clip_by_value(images, clip_value_min=0.0, clip_value_max=1.0),
reduction_indices=[3])
input_tensor=tf.clip_by_value(images, clip_value_min=0.0, clip_value_max=1.0),
axis=[3])
sat = (i_max - i_min) / (
tf.minimum(x=i_max + i_min, y=2.0 - i_max - i_min) + 1e-2)
saturation, _ = tf.nn.moments(sat, axes=[1, 2])
saturation, _ = tf.nn.moments(x=sat, axes=[1, 2])

repeatition = 1
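critics.py gets the same mechanical treatment: `reduction_indices` becomes `axis`, and `tf.nn.moments` plus the reductions take their inputs as keyword arguments. A small standalone sketch of the saturation statistic on a dummy batch; shapes and values are made up, the repo feeds real image crops here:

```python
import tensorflow as tf

images = tf.random.uniform([2, 8, 8, 3])  # dummy batch, values in [0, 1]

clipped = tf.clip_by_value(images, clip_value_min=0.0, clip_value_max=1.0)
i_max = tf.reduce_max(input_tensor=clipped, axis=[3])
i_min = tf.reduce_min(input_tensor=clipped, axis=[3])
sat = (i_max - i_min) / (tf.minimum(x=i_max + i_min, y=2.0 - i_max - i_min) + 1e-2)

# Mean saturation per image; the variance is discarded, as in critics.py.
saturation, _ = tf.nn.moments(x=sat, axes=[1, 2])
print(saturation.shape)  # (2,)
```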

16 changes: 8 additions & 8 deletions filters.py
@@ -1,6 +1,6 @@
import tensorflow as tf
import numpy as np
import tensorflow.contrib.layers as ly
import tf_slim as ly
from util import lrelu, rgb2lum, tanh_range, lerp
import cv2
import math
@@ -33,13 +33,13 @@ def extract_parameters(self, features):
self.cfg.fc1_size,
scope='fc1',
activation_fn=lrelu,
weights_initializer=tf.contrib.layers.xavier_initializer())
weights_initializer=tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
features = ly.fully_connected(
features,
output_dim,
scope='fc2',
activation_fn=None,
weights_initializer=tf.contrib.layers.xavier_initializer())
weights_initializer=tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
return features[:, :self.get_num_filter_parameters()], \
features[:, self.get_num_filter_parameters():]

@@ -113,7 +113,7 @@ def get_mask(self, img, mask_parameters):
return tf.ones(shape=(1, 1, 1, 1), dtype=tf.float32)
else:
print('* Masking Enabled')
with tf.name_scope(name='mask'):
with tf.compat.v1.name_scope(name='mask'):
# Six parameters for one filter
filter_input_range = 5
assert mask_parameters.shape[1] == self.get_num_mask_parameters()
@@ -155,7 +155,7 @@ def visualize_mask(self, debug_info, res):
return cv2.resize(
debug_info['mask'] * np.ones((1, 1, 3), dtype=np.float32),
dsize=res,
interpolation=cv2.cv2.INTER_NEAREST)
interpolation=cv2.INTER_NEAREST)

def draw_high_res_text(self, text, canvas):
cv2.putText(
@@ -264,7 +264,7 @@ def filter_param_regressor(self, features):
def process(self, img, param):
color_curve = param
# There will be no division by zero here unless the color filter range lower bound is 0
color_curve_sum = tf.reduce_sum(param, axis=4) + 1e-30
color_curve_sum = tf.reduce_sum(input_tensor=param, axis=4) + 1e-30
total_image = img * 0
for i in range(self.cfg.curve_steps):
total_image += tf.clip_by_value(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) * \
@@ -312,7 +312,7 @@ def filter_param_regressor(self, features):
def process(self, img, param):
# img = tf.minimum(img, 1.0)
tone_curve = param
tone_curve_sum = tf.reduce_sum(tone_curve, axis=4) + 1e-30
tone_curve_sum = tf.reduce_sum(input_tensor=tone_curve, axis=4) + 1e-30
total_image = img * 0
for i in range(self.cfg.curve_steps):
total_image += tf.clip_by_value(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
@@ -358,7 +358,7 @@ def get_num_mask_parameters(self):
# Closer to 1 values are applied by filter more strongly
# no additional TF variables inside
def get_mask(self, img, mask_parameters):
with tf.name_scope(name='mask'):
with tf.compat.v1.name_scope(name='mask'):
# Five parameters for one filter
filter_input_range = 5
assert mask_parameters.shape[1] == self.get_num_mask_parameters()
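The one substitution in filters.py that is not a pure rename is the initializer: `tf.contrib.layers.xavier_initializer()` is gone in TF2, and the `VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")` configuration used above is the usual Glorot-uniform equivalent. A minimal sketch of one fully connected layer with the replacement initializer, assuming `tf_slim` is imported as `ly` as elsewhere in the file and that this is called from the existing graph-building code:

```python
import tensorflow as tf
import tf_slim as ly

def fc_glorot(features, output_dim, scope):
    # Glorot/Xavier uniform expressed via VarianceScaling, as in the diff above.
    init = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1.0, mode="fan_avg", distribution="uniform")
    return ly.fully_connected(
        features, output_dim, scope=scope,
        activation_fn=None, weights_initializer=init)
```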
2 changes: 1 addition & 1 deletion fivek.py
@@ -62,7 +62,7 @@ def load(i):
images[i * augmentation_factor + j] = cv2.resize(
new_image,
dsize=(image_size, image_size),
interpolation=cv2.cv2.INTER_AREA)
interpolation=cv2.INTER_AREA)

p.map(load, list(range(len(files))))
print('Data pre-processing finished. Writing....')
144 changes: 72 additions & 72 deletions histogram_intersection.py
100644 → 100755
@@ -1,72 +1,72 @@
(Removed lines: the same code as the added lines below; the only change in this file is the blank-line spacing between the functions.)
import numpy as np
import cv2
import sys
import os
import random
from util import read_set

HIST_BINS = 32


def hist_intersection(a, b):
return np.minimum(a, b).sum()


def get_statistics(img):
img = np.clip(img, a_min=0.0, a_max=1.0)
HLS = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
lum = img[:, :, 0] * 0.27 + img[:, :, 1] * 0.67 + img[:, :, 2] * 0.06
sat = HLS[:, :, 2].mean()
return [lum.mean(), lum.std() * 2, sat]


def calc_hist(arr, nbins, xrange):
h, _ = np.histogram(a=arr, bins=nbins, range=xrange, density=False)
return h / float(len(arr))


def get_histograms(images):
statistics = np.array(list(zip(*map(get_statistics, images))))
hists = list(map(lambda x: calc_hist(x, HIST_BINS, (0.0, 1.0)), statistics))
return hists, statistics


def read_images(src, tag=None, set=None):
files = os.listdir(src)
images = []
if set is not None:
set = read_set(set)
for f in files:
if tag and f.find(tag) == -1:
continue
if set is not None:
if int(f.split('.')[0]) not in set:
continue
image = (cv2.imread(os.path.join(src, f))[:, :, ::-1] / 255.0).astype(np.float32)
longer_edge = min(image.shape[0], image.shape[1])
for i in range(4):
sx = random.randrange(0, image.shape[0] - longer_edge + 1)
sy = random.randrange(0, image.shape[1] - longer_edge + 1)
new_image = image[sx:sx + longer_edge, sy:sy + longer_edge]
patch = cv2.resize(new_image, dsize=(80, 80), interpolation=cv2.INTER_AREA)
for j in range(4):
target_size = 64
ssx = random.randrange(0, patch.shape[0] - target_size)
ssy = random.randrange(0, patch.shape[1] - target_size)
images.append(patch[ssx:ssx + target_size, ssy:ssy + target_size])
return images


if __name__ == '__main__':
output_src = sys.argv[1]
target_src = sys.argv[2]

output_imgs = read_images(output_src)
target_imgs = read_images(target_src)

output_hists, fake_stats = get_histograms(output_imgs)
target_hists, real_stats = get_histograms(target_imgs)
output_hists, real_hists = np.array(output_hists), np.array(target_hists)
hist_ints = list(map(hist_intersection, output_hists, real_hists))
print('Hist. Inter.: %.2f%% %.2f%% %.2f%%' % (hist_ints[0] * 100, hist_ints[1] * 100, hist_ints[2] * 100))
print(' Avg: %.2f%%' % (sum(hist_ints) / len(hist_ints) * 100))
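For reference, the metric this script computes is a plain histogram intersection over normalized histograms. A tiny worked example with hand-made bins:

```python
import numpy as np

# Two toy 32-bin histograms, each normalized to sum to 1.
a = np.full(32, 1.0 / 32)
b = np.zeros(32)
b[:16] = 1.0 / 16

# Histogram intersection: sum of bin-wise minima; 1.0 means identical histograms.
print(np.minimum(a, b).sum())  # 0.5
```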