| | import os |
| | from PIL import Image, ImageFile |
| | import torch |
| | from torch.utils.data import Dataset |
| | import torchvision.transforms as transforms |
| | import matplotlib.pyplot as plt |
| | from pathlib import Path |
| | from glob import glob |
| |
|
def adaptive_instance_normalization(x, y, eps=1e-5):
    """
    Adaptive Instance Normalization (AdaIN) for neural style transfer.

    Re-normalizes the content features ``x`` so that their per-channel
    spatial mean and standard deviation match those of the style
    features ``y``.

    Args:
        x (torch.FloatTensor): Content feature tensor; spatial stats are
            taken over dims 2 and 3 (assumes an (N, C, H, W) layout).
        y (torch.FloatTensor): Style feature tensor, same layout as ``x``.
        eps (float, default=1e-5): Small constant added to the standard
            deviations to avoid division by zero.

    Return:
        torch.FloatTensor: AdaIN style-transferred output, same shape as ``x``.
    """
    # Per-(batch, channel) statistics, kept broadcastable over (H, W).
    mu_x = torch.mean(x, dim=[2, 3], keepdim=True)
    mu_y = torch.mean(y, dim=[2, 3], keepdim=True)
    sigma_x = torch.std(x, dim=[2, 3], keepdim=True) + eps
    sigma_y = torch.std(y, dim=[2, 3], keepdim=True) + eps

    # Whiten with the content stats, then re-color with the style stats.
    return (x - mu_x) / sigma_x * sigma_y + mu_y
| |
|
def transform(size):
    """
    Build the image preprocessing pipeline: resize, then convert to tensor.

    Args:
        size (int): Target size passed to ``transforms.Resize``.

    Return:
        torchvision.transforms.Compose: The composed preprocessing steps.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
| |
|
def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
    """
    Generate and save an image that contains a row x col grid of images.

    Args:
        row (int): Number of rows in the grid.
        col (int): Number of columns in the grid.
        images (list of PIL image): Images to place in the grid, row-major.
        height (int, default=6): Height of each cell (inches).
        width (int, default=6): Width of each cell (inches).
        save_pth (str, default='grid.png'): Output file path.
    """
    # Total canvas size scales with the grid dimensions; use locals so the
    # per-cell parameters are not clobbered (the original reassigned them).
    fig_w = col * width
    fig_h = row * height
    fig = plt.figure(figsize=(fig_w, fig_h))
    for i, image in enumerate(images):
        plt.subplot(row, col, i + 1)
        plt.imshow(image)
        plt.axis('off')
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.savefig(save_pth)
    # BUG FIX: close the figure after saving; without this, repeated calls
    # accumulate open matplotlib figures and leak memory.
    plt.close(fig)
| |
|
| |
|
def linear_histogram_matching(content_tensor, style_tensor, eps=0.0):
    """
    Match the first two moments of style_tensor to those of content_tensor.

    Each (batch, channel) slice of the style image is shifted and scaled so
    that its spatial mean and standard deviation equal those of the
    corresponding content slice. Unlike the original implementation, the
    inputs are left unmodified and a new tensor is returned.

    Args:
        content_tensor (torch.FloatTensor): Content images; stats are taken
            over dims 2 and 3 (assumes an (N, C, H, W) layout).
        style_tensor (torch.FloatTensor): Style images, same layout.
        eps (float, default=0.0): Optional stabilizer added to the style
            standard deviation to guard against constant (zero-variance)
            channels. Default 0.0 preserves the original numerics.

    Return:
        torch.FloatTensor: Histogram-matched style images, same shape as
        ``style_tensor``.
    """
    # Biased (population) statistics per (batch, channel), broadcastable
    # over the spatial dims.
    mean_ct = content_tensor.mean(dim=(2, 3), keepdim=True)
    mean_st = style_tensor.mean(dim=(2, 3), keepdim=True)
    # BUG FIX: the original scaled by the *variance* ratio (its "std_*"
    # variables actually held torch.var), so the output std came out as
    # std_st * var_ct / var_st rather than std_ct. Take the square root so
    # the moments really match.
    std_ct = content_tensor.var(dim=(2, 3), unbiased=False, keepdim=True).sqrt()
    std_st = style_tensor.var(dim=(2, 3), unbiased=False, keepdim=True).sqrt() + eps
    return (style_tensor - mean_st) * std_ct / std_st + mean_ct
| |
|
| |
|
class TrainSet(Dataset):
    """
    Paired content/style training dataset.

    Pairs the i-th file of ``content_dir`` with the i-th file of
    ``style_dir`` (directory listing order) and applies the same
    resize / random-crop / normalize pipeline to both.
    """

    def __init__(self, content_dir, style_dir, crop_size=256):
        super().__init__()

        # Collect every file directly under each directory (unsorted,
        # whatever order glob yields).
        self.content_files = list(map(Path, glob(content_dir + '/*')))
        self.style_files = list(map(Path, glob(style_dir + '/*')))

        self.transform = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Relax PIL's global limits so huge or truncated images still load.
        Image.MAX_IMAGE_PIXELS = None
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    def __len__(self):
        # A sample needs one file from each list, so the shorter list wins.
        return min(len(self.content_files), len(self.style_files))

    def __getitem__(self, index):
        samples = []
        for files in (self.content_files, self.style_files):
            img = Image.open(files[index]).convert('RGB')
            samples.append(self.transform(img))
        # (content_sample, style_sample)
        return samples[0], samples[1]
| |
|
class Range(object):
    """
    Inclusive numeric interval used to restrict argparse input values.

    Equality is overloaded so that ``Range(a, b) == x`` is True exactly
    when ``a <= x <= b``.
    """

    def __init__(self, start, end):
        # Inclusive lower and upper bounds.
        self.start = start
        self.end = end

    def __eq__(self, other):
        # De Morgan form of `self.start <= other <= self.end`.
        return not (other < self.start or other > self.end)