Deconstructing Deep Learning + δeviations

Super resolution

Today we will look at Super Resolution in Python.

Do you use social media? Have you ever sent someone a picture, looked at it later, and found that the quality had become really bad? What if you could reverse that?

What if

The main objective -> Take an image and upsample it by 3x (that means if an image is initially 100x100, by the end we have a 300x300 image) without losing detail. If we then decided that we want the image to be 100x100 again, we could resize it back down, and it would look much sharper than it did before.

What else could we do

Suppose you do not want to use deep learning for this. What do you do then? Well, you have a few options. (None of which I recommend.)

  1. Become a Photoshop master (although now even that uses Deep learning)
  2. Go back in time and make sure you took a better picture
  3. Deal with it
  4. (Recommended) Read this post

Where could we use this?

The technical side

Now that we have that out of the way, let us define the problem further. We will explore it through a paper: Shi, Wenzhe, et al. "Real-time single image and video super-resolution using an efficient sub-pixel convolutional neural network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

The technical version of the objective goes something like this. Suppose we have two "spaces": one of high resolution and another of low resolution. Our network should learn how to map pixels from the low resolution space to the high resolution one. Efficiently.

To do this we need to follow standard deep learning training procedures along with a bit of modification. (Note that the total code is too long to post here so you can find all of it on my repo)

Get data.

Step 1 as usual would be to get data. For our purpose we can use the BSDS300 dataset. Just get it and extract it.

Now we need to be able to load the data. To do that, we need a few things.

  - We need to check if the image is a file and load it.
  - We need to perform transforms.
  - And we need to load the data.

Since most of it is just standard code, let us look at the new things only: center cropping. To crop an image we need to identify a valid size to crop to.

def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

We can then pass the resulting size to CenterCrop(crop_size) in our transforms pipeline.
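To make the crop arithmetic concrete, here is a small worked sketch. The crop size of 256 is just an example value I picked, and `low_res_size` is a hypothetical helper that assumes the low resolution input is made by shrinking the high resolution crop by the upscale factor.

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    # Largest size <= crop_size that divides evenly by the upscale factor.
    return crop_size - (crop_size % upscale_factor)


def low_res_size(crop_size, upscale_factor):
    # Hypothetical helper: side length of the downscaled network input
    # that pairs with a high resolution crop of side `crop_size`.
    return crop_size // upscale_factor


crop = calculate_valid_crop_size(256, 3)  # 256 % 3 == 1, so one pixel is dropped
small = low_res_size(crop, 3)
print(crop, small)  # 255 85
```

Because 255 is a multiple of 3, the 85x85 input scales back up to exactly 255x255 with nothing left over.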

class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, input_transform=None,
                 target_transform=None):
        super(DatasetFromFolder, self).__init__()

        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir)
                               if is_image_file(x)]
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        input = load_img(self.image_filenames
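The listing above is cut off in this copy of the post, so here is a rough sketch of how a folder dataset of this kind might be completed. It is not the original code: `FolderPairs`, the `loader` argument, and the body of `is_image_file` are my own assumptions, and the class is kept framework-free (in the real thing it would subclass the PyTorch Dataset class and the loader would open the file as an image).

```python
import copy
from os import listdir
from os.path import join

# Assumption: which file extensions count as images.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def is_image_file(filename):
    return filename.lower().endswith(IMAGE_EXTENSIONS)


class FolderPairs:
    """Framework-free stand-in for a dataset of (input, target) pairs."""

    def __init__(self, image_dir, loader, input_transform=None,
                 target_transform=None):
        self.image_filenames = sorted(
            join(image_dir, name)
            for name in listdir(image_dir)
            if is_image_file(name)
        )
        self.loader = loader  # e.g. a function that opens a file as an image
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image = self.loader(self.image_filenames[index])
        target = copy.copy(image)  # the target starts as a copy of the loaded image
        if self.input_transform:
            image = self.input_transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.image_filenames)
```

The idea is that both halves of a training pair come from the same file: the input transform would shrink it, while the target transform keeps it at full size.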

Create the network

Welcome to the deep end. Our network is actually simple since we are using PyTorch. We have 4 conv layers, a ReLU after each of the first three, and a special layer called PixelShuffle.

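What is PixelShuffle? Well, it is the defining contribution of the paper we are considering. So defining, in fact, that PyTorch has it built in. In simple terms, this layer is a shuffler. It takes a tensor of shape H (height) x W (width) x C·r^2 (channels, where r is the upscale factor) and gives us back a tensor of shape rH x rW x C. (PyTorch puts the channel dimension first, so there it goes from C·r^2 x H x W to C x rH x rW.) This "shuffle" is the final step of what the paper calls a sub-pixel convolutional layer. This is useful because all the convolutions run on the small low resolution image, and the rearrangement at the end is cheap and easy to parallelize.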
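To see what the shuffle does to individual values, here is a tiny pure-Python sketch working on nested lists instead of tensors. It is meant as an illustration only, written to follow the same channel-first ordering as `nn.PixelShuffle`; the function name and the toy input are my own.

```python
def pixel_shuffle(x, r):
    """Rearrange a nested list shaped [C*r*r][H][W] into [C][H*r][W*r]."""
    channels_in, height, width = len(x), len(x[0]), len(x[0][0])
    assert channels_in % (r * r) == 0, "channel count must be divisible by r^2"
    channels_out = channels_in // (r * r)
    out = [[[0] * (width * r) for _ in range(height * r)]
           for _ in range(channels_out)]
    for c in range(channels_out):
        for h in range(height):
            for w in range(width):
                for i in range(r):
                    for j in range(r):
                        # Channel (i, j) of group c fills sub-pixel (i, j)
                        # of the r x r output block at position (h, w).
                        source = x[c * r * r + i * r + j]
                        out[c][h * r + i][w * r + j] = source[h][w]
    return out


# Four 1x1 channels become one 2x2 image when r = 2.
print(pixel_shuffle([[[1]], [[2]], [[3]], [[4]]], 2))  # [[[1, 2], [3, 4]]]
```

No value is changed or averaged along the way. The layer only moves numbers from the channel dimension into the spatial dimensions.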

We also need to initialize our weights. This is done to make sure that our network starts off on a good footing. We use orthogonal initialization here.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# Main network
class Net(nn.Module):
    def __init__(self, upscale_factor):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5,5), (1,1), (2,2))
        self.conv2 = nn.Conv2d(64, 64, (3,3), (1,1), (1,1))
        self.conv3 = nn.Conv2d(64, 32, (3,3), (1,1), (1,1))
        self.conv4 = nn.Conv2d(32, upscale_factor**2, (3,3), (1,1), (1,1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
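If you want to convince yourself that the shapes line up, here is a small sketch that traces them through the network without needing PyTorch at all. The kernel, stride and padding values are copied from the Net above; the 85x85 input size and the helper names are assumptions I made for the example.

```python
def conv_out(size, kernel, stride, padding):
    # Standard convolution output-size formula (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(height, width, upscale_factor):
    # (out_channels, kernel, stride, padding) for conv1..conv4 as in Net.
    layers = [(64, 5, 1, 2), (64, 3, 1, 1), (32, 3, 1, 1),
              (upscale_factor ** 2, 3, 1, 1)]
    shapes = [(1, height, width)]
    for channels, kernel, stride, padding in layers:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
        shapes.append((channels, height, width))
    # PixelShuffle trades the r^2 channels for r-times larger height and width.
    r = upscale_factor
    shapes.append((shapes[-1][0] // (r * r), height * r, width * r))
    return shapes


for shape in trace_shapes(85, 85, 3):
    print(shape)
```

The padding keeps every conv layer at 85x85, so the only place the image grows is the final shuffle, which turns (9, 85, 85) into (1, 255, 255).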

Train!

Wow, we are almost halfway done. Now for the main training. We first load everything we need and create our Net with an upscale factor (i.e. how many times larger the output should be). We then initialize our optimizer. We will be using Adam here because it is a dependable default that tends to work well without much tuning. We also need a loss function.

Well, if you think about it, since we are trying to measure a pixel-wise difference, why not use exactly that: MSELoss.
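In case the name sounds scarier than it is, here is a minimal sketch of what a mean squared error computes, written in plain Python on flattened lists. The function name and the toy values are mine; with its default settings, PyTorch's MSELoss does the same averaging over all elements of the tensors.

```python
def mse(prediction, target):
    """Mean of squared per-pixel differences between two equal-length lists."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


# Two flattened 2x2 "images" with values in [0, 1]; only one pixel differs.
print(mse([0.0, 0.5, 1.0, 0.5], [0.0, 0.5, 0.5, 0.5]))  # 0.0625
```

One pixel is off by 0.5, which gives a squared error of 0.25, spread over 4 pixels.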

We follow the simple strategy of: batch -> push to GPU -> zero the optimizer's gradients -> calculate the loss -> backprop -> step the optimizer. (P.S. If there is anything here you do not understand, have a look at my blog.)

from tqdm import tqdm

def train(args, device, train_loader, model, epoch, optimizer, criterion):
    epoch_loss = 0
    # `device` comes from the caller, e.g. torch.device("cuda") for the GPU
    for batch_idx, batch in tqdm(enumerate(train_loader, 1)):
        input, target = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()  # zero gradients
        loss = criterion(model(input), target)  # calc loss
        epoch_loss += loss.item()
        loss.backward()  # backprop
        optimizer.step()  # update the weights

        print(f"Iteration: {batch_idx}, Loss: {loss.item()}")
    # Average the accumulated loss (not the epoch number) over the batches
    print(f"Epoch {epoch}, avg epoch_loss: {epoch_loss / len(train_loader)}")

The test loop is very similar to this one, so there is no need to talk about it separately. We save the model once it is done training.

Run it!

Now that we have a really cool model that works, how about putting it to use?

To do that, let us take an image, preprocess it, and convert it to a tensor. After that we can load the model and send this image to the GPU.

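import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor

# `opt` is assumed to hold the parsed command-line arguments
img = Image.open(opt.input_image).convert('YCbCr')
y, cb, cr = img.split()  # only the luminance channel Y goes through the network

# Newer PyTorch releases may need weights_only=False to load a whole saved model
model = torch.load(opt.model)
img_to_tensor = ToTensor()
input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])  # shape (1, 1, H, W)

if opt.cuda:
    model = model.cuda()
    input = input.cuda()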
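Notice that the snippet above converts the image to YCbCr and only sends the Y channel to the model. Y carries the brightness and Cb/Cr carry the color. As a rough sketch of what that conversion does to a single pixel, here are the JPEG-style full-range formulas, which as far as I know are what Pillow's 'YCbCr' mode is based on (the function name is mine, and Pillow's own rounding may differ slightly).

```python
def rgb_to_ycbcr(r, g, b):
    """Convert one 8-bit RGB pixel to full-range (JPEG-style) YCbCr."""
    y = 0.299 * r + 0.587 * g + 0.114 * b
    cb = 128 - 0.168736 * r - 0.331264 * g + 0.5 * b
    cr = 128 + 0.5 * r - 0.418688 * g - 0.081312 * b
    # Round and clamp each component to the 0..255 range of a byte.
    return tuple(min(255, max(0, round(v))) for v in (y, cb, cr))


print(rgb_to_ycbcr(255, 255, 255))  # (255, 128, 128)
print(rgb_to_ycbcr(0, 0, 0))        # (0, 128, 128)
```

White and black differ only in Y while both color channels sit at the neutral value 128. That is why the network (which has a single input channel) can focus on brightness, and the two color channels can be dealt with separately.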

Then we pass it through our model. What comes out is not a proper image yet: it is a tensor of floats. To fix that, we make it a NumPy array, scale the values up to the 0-255 range of 8-bit pixel values, and clip anything that falls outside that range. This will allow us to plot it or do whatever we want.

out = model(input)
out = out.cpu()
out_img_y = out[0].detach().numpy()
out_img_y *= 255.0
out_img_y = out_img_y.clip(0, 255)
out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')
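Here is the same scale, clip and convert idea as a tiny pure-Python sketch on a handful of values, assuming the model output sits roughly in [0, 1] but can overshoot a little on either side. The function name and the sample values are made up for the example.

```python
def to_uint8(values):
    """Scale floats from [0, 1] to [0, 255], clip, and truncate to integers."""
    result = []
    for v in values:
        scaled = v * 255.0
        clipped = min(255.0, max(0.0, scaled))
        result.append(int(clipped))  # drops the fractional part
    return result


print(to_uint8([-0.1, 0.0, 0.5, 1.2]))  # [0, 0, 127, 255]
```

Without the clip, a value like 1.2 would become 306, which does not fit in a byte and would wrap around to a wrong pixel value when cast to uint8.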

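We can now resize the two color channels to the new size with plain bicubic interpolation, merge everything back together, and save the image wherever we want to.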

out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge('YCbCr', [out_img_y, out_img_cb,
                            out_img_cr]).convert('RGB')

out_img.save(opt.output_filename)
print('output image saved to ', opt.output_filename)

And we are done!

Bonus

I also converted the script to work with videos. :) Go check out the repo

Some thoughts from the paper

Next steps

Thank you for reading this far! Hope you liked the article and it helped you understand the topic better. I would sincerely recommend you read the paper. It is an easy read. Super resolution is something I have been interested in for a long time, and just today I found an even better network for this task.

Do reach out if you liked the article or have any feedback for future ones. Hope you have a great day!