datasets.py
import json
import os

import torch
from torch.utils.data import Dataset
from PIL import Image

from utils import ImageTransforms


class SRDataset(Dataset):
    """
    A PyTorch Dataset to be used by a PyTorch DataLoader.
    """

    def __init__(self, data_folder, split, crop_size, scaling_factor, lr_img_type, hr_img_type, test_data_name=None):
        """
        :param data_folder: folder with the JSON data files
        :param split: one of 'train' or 'test'
        :param crop_size: crop size of the target HR images
        :param scaling_factor: the input LR images will be downsampled from the target HR images by this factor; this is also the factor by which the super-resolution model upscales
        :param lr_img_type: the format for the LR image supplied to the model; see convert_image() in utils.py for available formats
        :param hr_img_type: the format for the HR image supplied to the model; see convert_image() in utils.py for available formats
        :param test_data_name: if this is the 'test' split, the name of the test dataset (for example, "Set14")
        """
        self.data_folder = data_folder
        self.split = split.lower()
        self.crop_size = int(crop_size)
        self.scaling_factor = int(scaling_factor)
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type
        self.test_data_name = test_data_name

        # Validate the inputs
        assert self.split in {'train', 'test'}
        if self.split == 'test' and self.test_data_name is None:
            raise ValueError("Please provide the name of the test dataset!")
        assert lr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}
        assert hr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}

        # If this is a training dataset, the crop dimensions must be perfectly divisible by the scaling factor
        # (If this is a test dataset, images are not cropped to a fixed size, so this variable isn't used)
        if self.split == 'train':
            assert self.crop_size % self.scaling_factor == 0, \
                "Crop dimensions are not perfectly divisible by the scaling factor! This will lead to a mismatch " \
                "in the dimensions of the original HR patches and their super-resolved (SR) versions!"

        # Read the list of image paths
        if self.split == 'train':
            with open(os.path.join(data_folder, 'train_images.json'), 'r') as j:
                self.images = json.load(j)
        else:
            with open(os.path.join(data_folder, self.test_data_name + '_test_images.json'), 'r') as j:
                self.images = json.load(j)
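
        # Note (assumed from how self.images is used below, not stated in the original file):
        # each JSON manifest is expected to be a flat list of image paths, e.g.
        #   ["/data/train/0001.png", "/data/train/0002.png", ...]
        # so that self.images[i] can be opened directly with PIL in __getitem__.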

        # Select the correct set of transforms
        self.transform = ImageTransforms(split=self.split,
                                         crop_size=self.crop_size,
                                         scaling_factor=self.scaling_factor,
                                         lr_img_type=self.lr_img_type,
                                         hr_img_type=self.hr_img_type)

    def __getitem__(self, i):
        """
        This method is required to be defined for use in the PyTorch DataLoader.

        :param i: index to retrieve
        :return: the i-th pair of LR and HR images to be fed into the model
        """
        # Read image
        img = Image.open(self.images[i], mode='r')
        img = img.convert('RGB')

        # Flag images that may be too small to crop (the hard-coded 96 presumably corresponds to the training crop size)
        if img.width <= 96 or img.height <= 96:
            print(self.images[i], img.width, img.height)

        lr_img, hr_img = self.transform(img)

        return lr_img, hr_img

    def __len__(self):
        """
        This method is required to be defined for use in the PyTorch DataLoader.

        :return: size of this dataset (in number of images)
        """
        return len(self.images)
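

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): how this Dataset
# is typically wrapped in a DataLoader. The data folder './data', crop size 96,
# scaling factor 4, batch size and image types are assumed example values, not
# values mandated by this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_dataset = SRDataset(data_folder='./data',
                              split='train',
                              crop_size=96,
                              scaling_factor=4,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
                              num_workers=4, pin_memory=True)

    # Each batch yields an (LR, HR) pair of tensors; with the settings above the
    # HR patches should be 96x96 and the LR patches 24x24 (96 / 4).
    lr_imgs, hr_imgs = next(iter(train_loader))
    print(lr_imgs.shape, hr_imgs.shape)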