You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
173 lines
7.3 KiB
173 lines
7.3 KiB
import os
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import tensorflow.contrib.slim as slim
|
|
import cv2
|
|
|
|
|
|
def im2uint8(x):
    """Convert an image with values in [0, 1] to uint8 values in [0, 255].

    Values are clipped to [0, 1], multiplied by 255 and then cast to an
    8-bit integer type; the cast truncates rather than rounds.

    Args:
        x: a ``tf.Tensor`` (a ``tf.uint8`` tensor is returned), or anything
            ``np.clip`` accepts such as a NumPy array (a ``np.uint8`` array
            is returned).
    """
    # isinstance (instead of comparing __class__) also accepts subclasses of
    # tf.Tensor, which would otherwise fall through to the NumPy branch.
    if isinstance(x, tf.Tensor):
        return tf.cast(tf.clip_by_value(x, 0.0, 1.0) * 255.0, tf.uint8)
    return (np.clip(x, 0.0, 1.0) * 255.0).astype(np.uint8)
|
|
|
|
|
|
class DEBLUR(object):
    """Inference wrapper around a coarse-to-fine (multi-scale) deblurring net.

    ``build()`` creates the graph, opens a session and restores a checkpoint;
    ``test()`` runs the model on a single image or on every file of a folder.
    """

    def __init__(self, args):
        """Store inference settings.

        Args:
            args: object exposing ``max_height``, ``max_width`` and
                ``input_path`` attributes (e.g. an argparse namespace).
        """
        # Number of pyramid levels and the down-scaling factor between levels.
        self.n_levels = 3
        self.scale = 0.5
        # Fixed spatial size of the input placeholder; forward() resizes
        # and/or edge-pads every image to fit it.
        self.maxH = args.max_height
        self.maxW = args.max_width
        # A single image file or a folder of images.
        self.input_path = args.input_path

    def generator(self, inputs, reuse=False, scope='g_net'):
        """Build the multi-scale encoder/decoder graph.

        The network is applied ``self.n_levels`` times, from the coarsest
        scale (``self.scale ** (n_levels - 1)``) up to full resolution.  Each
        level sees the input resized to that scale, concatenated with the
        (gradient-stopped, resized) prediction of the previous level; the
        first level uses the input itself as the "previous" prediction.

        Callers must enable variable reuse (``build()`` passes
        ``tf.AUTO_REUSE``): the same variable scopes are entered several
        times while building one graph (see the NOTE below).

        Args:
            inputs: 4-D tensor ``[n, h, w, c]`` whose ``h`` and ``w`` are
                statically known (read via ``get_shape()``).  ``c`` must be 1
                because the 1-channel network output is added to the resized
                input.
            reuse: reuse flag for the enclosing variable scope.
            scope: name of the enclosing variable scope.

        Returns:
            List with one prediction tensor per level, coarsest first; the
            last element is the full-resolution prediction.
        """

        def ResnetBlock(x, dim, ksize, scope='rb'):
            # Two convolutions; only the first uses the arg_scope activation.
            # Despite the name there is no identity skip here: DenseBlock
            # adds the skip sums itself.
            with tf.variable_scope(scope):
                net = slim.conv2d(x, dim, [ksize, ksize], scope='conv1')
                net = slim.conv2d(net, dim, [ksize, ksize], activation_fn=None,
                                  scope='conv2')
                return net

        def DenseBlock(x, dim, ksize, scope='db'):
            # Each sub-block receives the sum of the block input and all
            # previous sub-block outputs; the block returns the total sum.
            with tf.variable_scope(scope):
                net1 = ResnetBlock(x, dim, ksize, scope='d1')
                net2 = ResnetBlock(x + net1, dim, ksize, scope='d2')
                net3 = ResnetBlock(x + net1 + net2, dim, ksize, scope='d3')
                net4 = ResnetBlock(x + net1 + net2 + net3, dim, ksize, scope='d4')
                return x + net1 + net2 + net3 + net4

        _, h, w, _ = inputs.get_shape().as_list()
        x_unwrap = []

        # NOTE: the two DenseBlocks of each stage share one scope name (e.g.
        # 'enc1_2' twice), and DenseBlock scopes carry no level suffix, so those
        # weights are shared within a stage and across pyramid levels.  Only the
        # plain convolutions get a per-level '_%d' suffix.  These names must
        # match the restored checkpoint -- do not rename them.
        with tf.variable_scope(scope, reuse=reuse):
            with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                                activation_fn=tf.nn.relu, padding='SAME',
                                normalizer_fn=None,
                                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                                biases_initializer=tf.constant_initializer(0.0)):
                inp_pred = inputs
                for i in range(self.n_levels):
                    scale = self.scale ** (self.n_levels - i - 1)
                    hi = int(round(h * scale))
                    wi = int(round(w * scale))
                    inp_blur = tf.image.resize_images(inputs, [hi, wi], method=0)
                    inp_pred = tf.stop_gradient(
                        tf.image.resize_images(inp_pred, [hi, wi], method=0))
                    inp_all = tf.concat([inp_blur, inp_pred], axis=3, name='inp')

                    # encoder
                    conv1_1 = slim.conv2d(inp_all, 32, [3, 3], scope='enc1_1_%d' % i)
                    conv1_2 = DenseBlock(conv1_1, 32, 3, scope='enc1_2')
                    conv1_3 = DenseBlock(conv1_2, 32, 3, scope='enc1_2')
                    conv2_1 = slim.conv2d(conv1_3, 64, [3, 3], stride=2, scope='enc2_1_%d' % i)
                    conv2_2 = DenseBlock(conv2_1, 64, 3, scope='enc2_2')
                    conv2_3 = DenseBlock(conv2_2, 64, 3, scope='enc2_2')
                    conv3_1 = slim.conv2d(conv2_3, 128, [3, 3], stride=2, scope='enc3_1_%d' % i)
                    conv3_2 = DenseBlock(conv3_1, 128, 3, scope='enc3_2')
                    conv3_3 = DenseBlock(conv3_2, 128, 3, scope='enc3_2')

                    deconv3_3 = conv3_3

                    # decoder (skip connections are additive)
                    deconv3_2 = DenseBlock(deconv3_3, 128, 3, scope='dec3_2')
                    deconv3_1 = DenseBlock(deconv3_2, 128, 3, scope='dec3_2')
                    deconv2_3 = slim.conv2d_transpose(deconv3_1, 64, [4, 4], stride=2, scope='dec2_3_%d' % i)
                    cat2 = deconv2_3 + conv2_3
                    deconv2_2 = DenseBlock(cat2, 64, 3, scope='dec2_2')
                    deconv2_1 = DenseBlock(deconv2_2, 64, 3, scope='dec2_2')
                    deconv1_3 = slim.conv2d_transpose(deconv2_1, 32, [4, 4], stride=2, scope='dec1_3_%d' % i)
                    cat1 = deconv1_3 + conv1_3
                    deconv1_2 = DenseBlock(cat1, 32, 3, scope='dec1_2')
                    deconv1_1 = DenseBlock(deconv1_2, 32, 3, scope='dec1_2')
                    inp_pred = slim.conv2d(deconv1_1, 1, [3, 3], activation_fn=None, scope='dec1_0_%d' % i)

                    # The network predicts a residual on top of the resized input.
                    inp_pred = inp_pred + inp_blur
                    x_unwrap.append(inp_pred)

        return x_unwrap

    def build(self, model_path):
        """Build the inference graph, open a session and restore the weights.

        Args:
            model_path: checkpoint directory, joined onto the directory of
                this source file (an absolute path replaces that prefix, as
                ``os.path.join`` does).  The checkpoint prefix inside it is
                ``deblur_model``.
        """
        # A batch of three single-channel images: forward() feeds the three
        # colour planes of one image as separate batch entries.
        self.inputs = tf.placeholder(shape=[3, self.maxH, self.maxW, 1], dtype=tf.float32)
        # generator() enters the same variable scopes repeatedly, so reuse
        # must be enabled.
        self.outputs = self.generator(self.inputs, reuse=tf.AUTO_REUSE)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.saver = tf.train.Saver()
        current_dir = os.path.dirname(os.path.realpath(__file__))
        checkpoint_dir = os.path.join(current_dir, model_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'deblur_model'))

    @staticmethod
    def _image_result_path(input_path):
        """Return ``input_path`` with ``_res`` inserted before its extension.

        Works for extensions of any length ('.png', '.jpeg', none at all).
        """
        root, ext = os.path.splitext(input_path)
        return root + '_res' + ext

    @staticmethod
    def _folder_result_path(input_path):
        """Return the sibling output folder ``<input folder>_res``.

        The path is normalised first so that a trailing separator does not
        put the output folder inside the input folder.
        """
        return os.path.normpath(input_path) + '_res'

    def test(self):
        """Deblur ``self.input_path`` and write the results to disk.

        A single image ``dir/name.ext`` is written to ``dir/name_res.ext``.
        For a folder ``dir`` every readable image file inside it is written,
        under the same file name, into ``dir_res`` (created if missing).
        Sub-directories and files OpenCV cannot read are skipped.
        """
        input_path = self.input_path
        if os.path.isfile(input_path):
            print(input_path)
            res = self.forward(input_path)
            cv2.imwrite(self._image_result_path(input_path), res)
            return

        names = [name for name in os.listdir(input_path)
                 if os.path.isfile(os.path.join(input_path, name))]
        print('Total %d images for deblurring' % len(names))
        output_path = self._folder_result_path(input_path)
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        for name in names:
            print(name)
            try:
                res = self.forward(os.path.join(input_path, name))
            except IOError as err:
                # Non-image files (e.g. thumbnails databases) must not abort
                # the whole batch.
                print('Skipping %s: %s' % (name, err))
                continue
            cv2.imwrite(os.path.join(output_path, name), res)

    def forward(self, imgpath):
        """Deblur one image file and return the result.

        The three colour planes are processed as a batch of single-channel
        images.  Portrait images are transposed so that width >= height,
        images larger than ``maxH x maxW`` are downscaled (bicubic), and the
        image is edge-padded to the placeholder size; all of this is undone
        on the output.  Pixel values are divided by 255, so 8-bit input is
        assumed.

        Args:
            imgpath: path of the image to deblur.

        Returns:
            ``uint8`` array ``(h, w, 3)`` in the channel order OpenCV loaded
            (BGR).  Images that do not have exactly three channels are
            returned unmodified, as loaded.

        Raises:
            IOError: if OpenCV cannot read ``imgpath``.
        """
        img = cv2.imread(imgpath, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise IOError('Cannot read image: %s' % imgpath)
        if img.ndim != 3 or img.shape[2] != 3:
            print('Image is not a color image, return the input image!')
            return img

        blur = img.astype('float32')
        h, w = blur.shape[:2]

        # make sure the width is larger than the height
        rot = h > w
        if rot:
            blur = np.transpose(blur, [1, 0, 2])
            h, w = w, h

        H = self.maxH
        W = self.maxW
        resize = h > H or w > W
        if resize:
            scale = min(1.0 * H / h, 1.0 * W / w)
            new_h = int(round(h * scale))
            new_w = int(round(w * scale))
            print('Original Size:', h, w, 'Resize by scale factor', scale, ' to:', new_h, new_w)
            blur = cv2.resize(blur, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        else:
            new_h, new_w = h, w
        blur_pad = np.pad(blur, ((0, H - new_h), (0, W - new_w), (0, 0)), 'edge')
        # (H, W, 3) -> (1, H, W, 3) -> (3, H, W, 1): one batch entry per plane.
        blur_pad = np.transpose(np.expand_dims(blur_pad, 0), (3, 1, 2, 0))

        deblur = self.sess.run(self.outputs, feed_dict={self.inputs: blur_pad / 255.0})
        # Keep the finest level only: (3, H, W, 1) -> (1, H, W, 3) -> (H, W, 3).
        res = np.transpose(deblur[-1], (3, 1, 2, 0))
        res = im2uint8(res[0, :, :, :])

        # Undo the padding, then the resize, then the transpose.
        res = res[:new_h, :new_w, :]
        if resize:
            res = cv2.resize(res, (w, h), interpolation=cv2.INTER_CUBIC)
        if rot:
            res = np.transpose(res, [1, 0, 2])
        return res