The goal of super-resolution (SR) is to recover a high resolution image from a low resolution input, or as they might say on any modern crime show, enhance!
To accomplish this goal, we will deploy the super-resolution convolutional neural network (SRCNN) using Keras. This network was published in the paper "Image Super-Resolution Using Deep Convolutional Networks" by Chao Dong et al. in 2014. You can read the full paper at https://arxiv.org/abs/1501.00092.
As the title suggests, the SRCNN is a deep convolutional neural network that learns an end-to-end mapping from low-resolution to high-resolution images. As a result, we can use it to improve the quality of low-resolution images. To evaluate the performance of this network, we will use three image quality metrics: peak signal-to-noise ratio (PSNR), mean squared error (MSE), and the structural similarity (SSIM) index.
Furthermore, we will be using OpenCV, the Open Source Computer Vision Library. OpenCV was originally developed by Intel and is used for many real-time computer vision applications. In this project, we will use it to pre- and post-process our images. As you will see later, we will frequently convert our images back and forth between the RGB, BGR, and YCrCb color spaces. This is necessary because the SRCNN was trained on the luminance (Y) channel in the YCrCb color space.
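To make the color-space step concrete, here is a minimal sketch of the per-pixel arithmetic behind a BGR to YCrCb conversion for 8-bit values, following the formulas given in the OpenCV color conversion documentation. The function name `bgr_to_ycrcb` is hypothetical and exists only for this illustration; in the project itself we call `cv2.cvtColor`, which also takes care of rounding and saturation.

```python
def bgr_to_ycrcb(b, g, r, delta=128.0):
    """Convert one 8-bit BGR pixel to (Y, Cr, Cb) as floats (illustrative only)."""
    # luminance: weighted sum of the color channels
    y = 0.299 * r + 0.587 * g + 0.114 * b
    # chrominance: scaled color differences, offset so that gray maps to delta
    cr = (r - y) * 0.713 + delta
    cb = (b - y) * 0.564 + delta
    return y, cr, cb


# a neutral gray pixel carries no color information,
# so Cr and Cb should both sit at the offset value
y, cr, cb = bgr_to_ycrcb(100, 100, 100)
print(round(y), round(cr), round(cb))
```

Because the Y channel holds the brightness detail, the network only has to reconstruct that one channel, while Cr and Cb are carried over unchanged.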
import sys
import keras
import cv2
import numpy
import matplotlib
import skimage
print('Python: {}'.format(sys.version))
print('Keras: {}'.format(keras.__version__))
print('OpenCV: {}'.format(cv2.__version__))
print('NumPy: {}'.format(numpy.__version__))
print('Matplotlib: {}'.format(matplotlib.__version__))
print('Scikit-Image: {}'.format(skimage.__version__))
from keras.models import Sequential
from keras.layers import Conv2D
from keras.optimizers import Adam
from skimage.measure import compare_ssim as ssim
from matplotlib import pyplot as plt
import cv2
import numpy as np
import math
import os
%matplotlib inline
To start, let's define a couple of functions that we can use to calculate the PSNR, MSE, and SSIM. The structural similarity (SSIM) index was imported directly from the scikit-image library; however, we will have to define our own functions for the PSNR and MSE. Furthermore, we will wrap all three of these metrics into a single function that we can call later.
def psnr(target, ref):
    # peak signal-to-noise ratio in dB, assuming 8-bit images (peak value 255)
    target_data = target.astype(float)
    ref_data = ref.astype(float)

    diff = (ref_data - target_data).flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))

    # identical images have zero error, so the PSNR is unbounded
    if rmse == 0:
        return float('inf')
    return 20 * math.log10(255. / rmse)


def mse(target, ref):
    # sum of the squared differences over all channels, divided by the number of pixels
    err = np.sum((target.astype('float') - ref.astype('float')) ** 2)
    err /= float(target.shape[0] * target.shape[1])
    return err


def compare_images(target, ref):
    # collect all three metrics as [PSNR, MSE, SSIM]
    scores = []
    scores.append(psnr(target, ref))
    scores.append(mse(target, ref))
    scores.append(ssim(target, ref, multichannel=True))
    return scores
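As a quick sanity check on the PSNR formula, the sketch below works through a tiny example in plain Python, without NumPy, using two hypothetical grayscale patches that differ by 5 at every pixel. The helper names `mean_squared_error` and `psnr_from_mse` are assumptions made for this illustration. Note that this per-value mean is not the same quantity as the `mse` function in this notebook, which for color images divides by the pixel count only.

```python
import math


def mean_squared_error(a, b):
    """Mean of the squared differences between two equally sized flat sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr_from_mse(mse_value, peak=255.0):
    """PSNR in dB: 10 * log10(peak^2 / MSE), which equals 20 * log10(peak / RMSE)."""
    if mse_value == 0:
        return float('inf')
    return 10 * math.log10(peak ** 2 / mse_value)


ref = [50, 100, 150, 200]        # hypothetical reference pixels
target = [55, 105, 155, 205]     # same pixels, each off by 5
m = mean_squared_error(target, ref)
print(m, round(psnr_from_mse(m), 2))
```

A constant error of 5 gray levels gives an MSE of 25 and a PSNR of roughly 34 dB, which is a useful reference point when reading the scores printed later.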
For this project, we will be using the same images that were used in the original SRCNN paper. We can download these images from http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html. The .zip file identified as the MATLAB code contains the images we want. Copy both the Set5 and Set14 datasets into a new folder called 'source'.
Now that we have some images, we want to produce low-resolution versions of these same images. We can accomplish this with OpenCV by resizing each image down by a given factor and then back up to its original size. There are several interpolation methods that can be used to resize images; however, we will be using bilinear interpolation.
Once we produce these low resolution images, we can save them in a new folder.
def prepare_images(path, factor):
    # make sure the output folder exists; cv2.imwrite does not create it
    os.makedirs('images', exist_ok=True)

    # loop through the files in the directory
    for file in os.listdir(path):
        # open the file
        img = cv2.imread(os.path.join(path, file))

        # skip anything OpenCV cannot read as an image
        if img is None:
            continue

        # find old and new image dimensions
        h, w, _ = img.shape
        new_height = int(h / factor)
        new_width = int(w / factor)

        # resize the image - down
        img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

        # resize the image - up
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)

        # save the image
        print('Saving {}'.format(file))
        cv2.imwrite('images/{}'.format(file), img)


prepare_images('source/', 2)
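One detail in the resizing step is easy to get backwards: `img.shape` reports (height, width, channels), while `cv2.resize` expects its target size as (width, height). The plain-Python sketch below only reproduces the size arithmetic, with no OpenCV call; the helper name `resize_plan` and the example dimensions are hypothetical.

```python
def resize_plan(h, w, factor):
    """Return the (width, height) targets for the down- and up-sampling steps."""
    new_h, new_w = int(h / factor), int(w / factor)
    down = (new_w, new_h)   # size passed to the first resize call
    up = (w, h)             # size passed to the second resize call
    return down, up


# a hypothetical image that is 341 pixels tall and 512 pixels wide
down, up = resize_plan(341, 512, 2)
print(down, up)
```

Because `int()` truncates, an odd dimension loses its remainder on the way down, but the second resize restores the exact original size, so the degraded image stays aligned with its reference.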
To ensure that our image quality metrics are being calculated correctly and that the images were effectively degraded, let's calculate the PSNR, MSE, and SSIM between our reference images and the degraded images that we just prepared.
for file in os.listdir('images/'):
    # open target and reference images
    target = cv2.imread('images/{}'.format(file))
    ref = cv2.imread('source/{}'.format(file))

    # calculate score
    scores = compare_images(target, ref)

    # print all three scores with new line characters (\n)
    print('{}\nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(file, scores[0], scores[1], scores[2]))
Now that we have our low-resolution images and all three image quality metrics functioning properly, we can start building the SRCNN. In Keras, it's as simple as adding layers one after the other. The architecture and hyperparameters of the SRCNN can be obtained from the publication referenced above. Keep in mind that the layer shapes defined here must match the pre-trained weights we load later.
def model():
    # define model type
    SRCNN = Sequential()

    # add model layers
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9), kernel_initializer='glorot_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1)))
    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3), kernel_initializer='glorot_uniform',
                     activation='relu', padding='same', use_bias=True))
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5), kernel_initializer='glorot_uniform',
                     activation='linear', padding='valid', use_bias=True))

    # define optimizer (newer Keras releases name this argument learning_rate)
    adam = Adam(lr=0.0003)

    # compile model
    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])

    return SRCNN
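Two consequences of this layer stack are worth checking by hand: how much the 'valid' convolutions shrink the image, and how many trainable parameters the model has. The sketch below does the arithmetic in plain Python, assuming the default stride of 1; the `LAYERS` table and both helper names are hypothetical and simply mirror the three `Conv2D` calls in `model()`.

```python
# (kernel size, filters, padding) for the three layers of the model
LAYERS = [(9, 128, 'valid'), (3, 64, 'same'), (5, 1, 'valid')]


def output_size(h, w, layers=LAYERS):
    """Spatial size after all layers; a 'valid' k x k conv removes k - 1 pixels."""
    for k, _, padding in layers:
        if padding == 'valid':
            h, w = h - (k - 1), w - (k - 1)
    return h, w


def parameter_count(in_channels=1, layers=LAYERS):
    """Total weights and biases across the stack."""
    total = 0
    for k, filters, _ in layers:
        total += k * k * in_channels * filters + filters  # weights + biases
        in_channels = filters
    return total


h, w = output_size(100, 100)
print(h, w, (100 - h) // 2, parameter_count())
```

The output is 12 pixels smaller than the input in each dimension, that is, 6 pixels per side, which is why a border of 6 is removed from the reference and degraded images before they are compared with the network output.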
Now that we have defined our model, we can use it for single-image super-resolution. However, before we do this, we will need to define a couple of image processing functions. Furthermore, it will be necessary to preprocess the images extensively before using them as inputs to the network. This processing will include cropping and color space conversions.
Additionally, to save us the time it takes to train a deep neural network, we will be loading pre-trained weights for the SRCNN. These weights can be found at the following GitHub page: https://github.com/MarkPrecursor/SRCNN-keras
Once we have tested our network, we can perform single-image super-resolution on all of our input images. Furthermore, after processing, we can calculate the PSNR, MSE, and SSIM on the images that we produce. We can save these images directly or create subplots to conveniently display the original, low resolution, and high resolution images side by side.
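The cropping mentioned above comes down to simple index arithmetic, sketched here in plain Python on shapes and lists rather than on real images. The names `modcrop_shape` and `shave_rows` are hypothetical stand-ins for this illustration, not the notebook's own helpers. The second half shows a general slicing pitfall: a slice written as `[border:-border]` returns nothing when the border is 0, which does not matter in this notebook because the border is always 6.

```python
def modcrop_shape(h, w, scale):
    """Largest height and width not exceeding (h, w) that are multiples of scale."""
    return h - h % scale, w - w % scale


def shave_rows(rows, border):
    """Drop `border` entries from each end of a list; border=0 keeps everything."""
    return rows[border:len(rows) - border]


# a hypothetical 341 x 512 image cropped for a scale of 3
print(modcrop_shape(341, 512, 3))

rows = list(range(10))
print(rows[0:-0])            # negative-zero slice: empty
print(shave_rows(rows, 0))   # explicit end index: unchanged
print(shave_rows(rows, 2))
```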
def modcrop(img, scale):
    # crop the image so that its height and width are multiples of the scale
    tmpsz = img.shape
    sz = tmpsz[0:2]
    sz = sz - np.mod(sz, scale)
    img = img[0:sz[0], 0:sz[1]]
    return img


def shave(image, border):
    # remove a fixed border from every side of the image
    img = image[border: -border, border: -border]
    return img


def predict(image_path):
    # load the srcnn model with weights
    srcnn = model()
    srcnn.load_weights('3051crop_weight_200.h5')

    # load the degraded and reference images
    path, file = os.path.split(image_path)
    degraded = cv2.imread(image_path)
    ref = cv2.imread('source/{}'.format(file))

    # preprocess the image with modcrop
    ref = modcrop(ref, 3)
    degraded = modcrop(degraded, 3)

    # convert the image to YCrCb - (srcnn trained on Y channel)
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)

    # create image slice and normalize
    Y = np.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255

    # perform super-resolution with srcnn
    pre = srcnn.predict(Y, batch_size=1)

    # post-process output
    pre *= 255
    pre[pre[:] > 255] = 255
    pre[pre[:] < 0] = 0
    pre = pre.astype(np.uint8)

    # copy Y channel back to image and convert to BGR
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)

    # remove border from reference and degraded image
    ref = shave(ref.astype(np.uint8), 6)
    degraded = shave(degraded.astype(np.uint8), 6)

    # image quality calculations
    scores = []
    scores.append(compare_images(degraded, ref))
    scores.append(compare_images(output, ref))

    # return images and scores
    return ref, degraded, output, scores
ref, degraded, output, scores = predict('images/flowers.bmp')
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[0][0], scores[0][1], scores[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[1][0], scores[1][1], scores[1][2]))
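The post-processing inside `predict` rescales the network output from its nominal [0, 1] range back to 8-bit intensities and clips anything that falls outside [0, 255] before the cast. Here is a scalar, plain-Python mirror of that step; the function name `to_uint8` is hypothetical, and `int()` stands in for the truncating cast to `np.uint8`.

```python
def to_uint8(value):
    """Map one network output value (nominally in [0, 1]) to an 8-bit intensity."""
    scaled = value * 255
    clipped = min(max(scaled, 0), 255)   # keep the value inside the 8-bit range
    return int(clipped)                  # drops the fractional part


# overshoot, undershoot, and a mid-gray value
print(to_uint8(1.2), to_uint8(-0.1), to_uint8(0.5))
```

Clipping before the cast matters because a linear output layer can overshoot slightly around sharp edges, and out-of-range values would otherwise not map to sensible intensities.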
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
axs[0].set_title('Original')
axs[1].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
axs[1].set_title('Degraded')
axs[2].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
axs[2].set_title('SRCNN')
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])
# make sure the output folder exists before saving any figures
os.makedirs('output', exist_ok=True)

for file in os.listdir('images'):
    # perform super-resolution
    ref, degraded, output, scores = predict('images/{}'.format(file))

    # display images as subplots
    fig, axs = plt.subplots(1, 3, figsize=(20, 8))
    axs[0].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original')
    axs[1].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Degraded')
    axs[1].set(xlabel='PSNR: {}\nMSE: {} \nSSIM: {}'.format(scores[0][0], scores[0][1], scores[0][2]))
    axs[2].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
    axs[2].set_title('SRCNN')
    axs[2].set(xlabel='PSNR: {} \nMSE: {} \nSSIM: {}'.format(scores[1][0], scores[1][1], scores[1][2]))

    # remove the x and y ticks
    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])

    print('Saving {}'.format(file))
    fig.savefig('output/{}.png'.format(os.path.splitext(file)[0]))
    plt.close()