You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
7.3 KiB

import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2
def im2uint8(x):
    """Convert an image with values in [0, 1] to 8-bit values in [0, 255].

    Values outside [0, 1] are clipped before scaling by 255.

    Args:
        x: a ``tf.Tensor`` or anything ``np.clip`` accepts (NumPy array,
            nested list, scalar).

    Returns:
        A ``tf.uint8`` tensor when ``x`` is a ``tf.Tensor``; otherwise a
        NumPy value of dtype ``np.uint8``. The cast truncates toward zero
        (e.g. 0.5 -> 127).
    """
    # NumPy arrays are tested first so the pure-NumPy path never has to look
    # at TensorFlow.  isinstance (rather than an exact class comparison) also
    # accepts subclasses of tf.Tensor.
    if not isinstance(x, np.ndarray) and isinstance(x, tf.Tensor):
        return tf.cast(tf.clip_by_value(x, 0.0, 1.0) * 255.0, tf.uint8)
    scaled = np.clip(x, 0.0, 1.0) * 255.0
    return scaled.astype(np.uint8)
class DEBLUR(object):
    """Coarse-to-fine deblurring network plus a file/folder inference driver.

    The generator runs ``n_levels`` passes over progressively larger copies
    of the input.  At inference time every colour channel is fed as a
    separate single-channel batch item (placeholder shape
    ``[3, maxH, maxW, 1]``), so the network itself never sees channel order.
    """

    def __init__(self, args):
        """Store inference settings.

        Args:
            args: object exposing ``max_height``, ``max_width`` and
                ``input_path`` attributes.
        """
        self.n_levels = 3   # number of scales processed by the generator
        self.scale = 0.5    # size ratio between two consecutive scales
        self.maxH = args.max_height
        self.maxW = args.max_width
        self.input_path = args.input_path

    def generator(self, inputs, reuse=False, scope='g_net'):
        """Build the multi-scale encoder/decoder graph.

        Args:
            inputs: 4-D tensor with a fully static shape, read as
                (batch, height, width, channels).  The last layer emits one
                channel and is added to the resized input, so a single-channel
                input is expected.
            reuse: forwarded to ``tf.variable_scope``.  NOTE: several dense
                blocks deliberately use the same scope name (e.g. 'enc1_2'
                twice, and again on every level), so their weights are shared;
                this needs ``reuse=tf.AUTO_REUSE`` as passed by ``build``.
                With the default ``False`` TensorFlow is expected to reject
                the second creation of those variables.
            scope: name of the outermost variable scope.

        Returns:
            A list with one prediction tensor per level, coarsest first; the
            last entry has the spatial size of ``inputs``.
        """

        def ResnetBlock(x, dim, ksize, scope='rb'):
            # Two stacked convolutions, the second one without activation.
            # Despite the name no skip connection is added here; the
            # summation happens in DenseBlock.
            with tf.variable_scope(scope):
                net = slim.conv2d(x, dim, [ksize, ksize], scope='conv1')
                net = slim.conv2d(net, dim, [ksize, ksize],
                                  activation_fn=None, scope='conv2')
            return net

        def DenseBlock(x, dim, ksize, scope='db'):
            # Four conv pairs ('d1'..'d4'); each one receives the input plus
            # the outputs of all previous pairs, and the block returns the
            # total sum.  Summation order matches x+net1+net2+net3+net4.
            with tf.variable_scope(scope):
                total = x
                for k in range(1, 5):
                    total = total + ResnetBlock(total, dim, ksize, scope='d%d' % k)
            return total

        _, h, w, _ = inputs.get_shape().as_list()
        x_unwrap = []
        with tf.variable_scope(scope, reuse=reuse):
            with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                                activation_fn=tf.nn.relu, padding='SAME', normalizer_fn=None,
                                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                                biases_initializer=tf.constant_initializer(0.0)):
                # The first level is seeded with the (down-scaled) blurry input.
                inp_pred = inputs
                for i in range(self.n_levels):
                    # Level 0 is the coarsest; the last level is full size.
                    scale = self.scale ** (self.n_levels - i - 1)
                    hi = int(round(h * scale))
                    wi = int(round(w * scale))
                    # method=0 selects bilinear resizing in TF 1.x.
                    inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                    # The previous prediction is up-scaled and detached so
                    # gradients do not flow between levels.
                    inp_pred = tf.stop_gradient(
                        tf.image.resize_images(inp_pred, [hi, wi], method=0))
                    inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')

                    # encoder -- scope names must stay exactly as they are,
                    # otherwise existing checkpoints no longer match.
                    conv1_1 = slim.conv2d(inp_all, 32, [3, 3], scope='enc1_1_%d' % i)
                    conv1_2 = DenseBlock(conv1_1, 32, 3, scope='enc1_2')
                    conv1_3 = DenseBlock(conv1_2, 32, 3, scope='enc1_2')
                    conv2_1 = slim.conv2d(conv1_3, 64, [3, 3], stride=2, scope='enc2_1_%d' % i)
                    conv2_2 = DenseBlock(conv2_1, 64, 3, scope='enc2_2')
                    conv2_3 = DenseBlock(conv2_2, 64, 3, scope='enc2_2')
                    conv3_1 = slim.conv2d(conv2_3, 128, [3, 3], stride=2, scope='enc3_1_%d' % i)
                    conv3_2 = DenseBlock(conv3_1, 128, 3, scope='enc3_2')
                    conv3_3 = DenseBlock(conv3_2, 128, 3, scope='enc3_2')
                    deconv3_3 = conv3_3

                    # decoder -- the additive skips below require matching
                    # spatial sizes (presumably hi/wi divisible by 4 at every
                    # level; verify for new maxH/maxW values).
                    deconv3_2 = DenseBlock(deconv3_3, 128, 3, scope='dec3_2')
                    deconv3_1 = DenseBlock(deconv3_2, 128, 3, scope='dec3_2')
                    deconv2_3 = slim.conv2d_transpose(deconv3_1, 64, [4, 4], stride=2,
                                                      scope='dec2_3_%d' % i)
                    cat2 = deconv2_3 + conv2_3
                    deconv2_2 = DenseBlock(cat2, 64, 3, scope='dec2_2')
                    deconv2_1 = DenseBlock(deconv2_2, 64, 3, scope='dec2_2')
                    deconv1_3 = slim.conv2d_transpose(deconv2_1, 32, [4, 4], stride=2,
                                                      scope='dec1_3_%d' % i)
                    cat1 = deconv1_3 + conv1_3
                    deconv1_2 = DenseBlock(cat1, 32, 3, scope='dec1_2')
                    deconv1_1 = DenseBlock(deconv1_2, 32, 3, scope='dec1_2')

                    # The network predicts a correction that is added to the
                    # blurry input of this level.
                    inp_pred = slim.conv2d(deconv1_1, 1, [3, 3], activation_fn=None,
                                           scope='dec1_0_%d' % i)
                    inp_pred = inp_pred + inp_blur
                    x_unwrap.append(inp_pred)
        return x_unwrap

    def build(self, model_path):
        """Create the inference graph and session, then restore the weights.

        Args:
            model_path: checkpoint directory, joined onto the directory that
                contains this source file.  The checkpoint prefix inside it
                is 'deblur_model'.
        """
        # One batch item per colour channel, each with a single channel.
        self.inputs = tf.placeholder(shape=[3, self.maxH, self.maxW, 1], dtype=tf.float32)
        # AUTO_REUSE is required because the generator re-enters some scopes.
        self.outputs = self.generator(self.inputs, reuse=tf.AUTO_REUSE)
        self.sess = tf.Session(
            config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.saver = tf.train.Saver()
        current_dir = os.path.dirname(os.path.realpath(__file__))
        checkpoint_dir = os.path.join(current_dir, model_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'deblur_model'))

    @staticmethod
    def _single_output_path(input_path):
        """Return ``input_path`` with '_res' inserted before its extension."""
        # splitext copes with extensions of any length ('.jpeg', '.tiff', none).
        root, ext = os.path.splitext(input_path)
        return root + '_res' + ext

    @staticmethod
    def _folder_output_path(input_path):
        """Return the sibling '<folder>_res' directory for an input folder."""
        # Drop trailing separators so 'dir/' maps to 'dir_res', not 'dir/_res'.
        trimmed = input_path.rstrip(os.sep + (os.altsep or ''))
        return (trimmed or input_path) + '_res'

    def test(self):
        """Deblur ``self.input_path`` (a single image or a folder of images).

        A single file is written next to the input with a '_res' suffix
        before the extension.  For a folder, every regular file in it is
        processed and written under the same name into '<folder>_res';
        sub-directories are skipped and unreadable files are reported and
        skipped.
        """
        input_path = self.input_path
        if os.path.isfile(input_path):
            print(input_path)
            res = self.forward(input_path)
            cv2.imwrite(self._single_output_path(input_path), res)
            return

        imgs = [name for name in os.listdir(input_path)
                if os.path.isfile(os.path.join(input_path, name))]
        print('Total %d images for deblurring' % len(imgs))
        output_path = self._folder_output_path(input_path)
        if not os.path.isdir(output_path):
            os.makedirs(output_path)
        for name in imgs:
            print(name)
            try:
                res = self.forward(os.path.join(input_path, name))
            except IOError as err:
                # One stray non-image file must not abort the whole batch.
                print('Skipped: %s' % err)
                continue
            cv2.imwrite(os.path.join(output_path, name), res)

    def forward(self, imgpath):
        """Deblur one image file and return the result.

        Args:
            imgpath: path of the image to load with OpenCV.

        Returns:
            A ``uint8`` array of the input's height and width, with three
            channels in the same channel order as loaded.  Images that do
            not have exactly three channels are returned unchanged, as
            loaded.

        Raises:
            IOError: if OpenCV cannot read ``imgpath``.
        """
        raw = cv2.imread(imgpath, cv2.IMREAD_UNCHANGED)
        if raw is None:
            raise IOError('Cannot read image: %s' % imgpath)
        # Grayscale files load as 2-D arrays, so test the rank before
        # looking at the channel count.
        if raw.ndim != 3 or raw.shape[2] != 3:
            print('Image is not a color image, return the input image!')
            return raw

        # No channel flips are needed: every channel goes through the
        # network on its own, so the order is carried through untouched.
        blur = raw.astype('float32')
        h, w = blur.shape[:2]

        # make sure the width is larger than the height
        rot = h > w
        if rot:
            blur = np.transpose(blur, [1, 0, 2])
            h, w = w, h

        H = self.maxH
        W = self.maxW
        resize = h > H or w > W
        if resize:
            # Shrink (keeping the aspect ratio) to fit the placeholder.
            scale = min(1.0 * H / h, 1.0 * W / w)
            new_h = int(round(h * scale))
            new_w = int(round(w * scale))
            print('Original Size:', h, w, 'Resize by scale factor', scale, ' to:', new_h, new_w)
            blur = cv2.resize(blur, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        else:
            new_h, new_w = h, w

        # Fill the rest of the fixed-size placeholder by repeating edge pixels.
        blur_pad = np.pad(blur, ((0, H - new_h), (0, W - new_w), (0, 0)), 'edge')
        # (H, W, 3) -> (1, H, W, 3) -> (3, H, W, 1): channels become the batch.
        blur_pad = np.transpose(np.expand_dims(blur_pad, 0), (3, 1, 2, 0))
        deblur = self.sess.run(self.outputs, feed_dict={self.inputs: blur_pad / 255.0})

        # The last list entry is the full-resolution prediction.
        res = np.transpose(deblur[-1], (3, 1, 2, 0))
        res = im2uint8(res[0, :, :, :])

        # crop the image into original size
        res = res[:new_h, :new_w, :]
        if resize:
            res = cv2.resize(res, (w, h), interpolation=cv2.INTER_CUBIC)
        if rot:
            res = np.transpose(res, [1, 0, 2])
        return res