Commit 8955a88

Author: xunchao li
Commit message: first commit
0 parents, commit 8955a88

32 files changed: +3734, -0 lines

.DS_Store (8 KB, binary file not shown)

data_new/__init__.py (+52 lines)

from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset

# This is a simple wrapper class for ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
    def __init__(self, args):
        self.loader_train = None
        # Build the training loader from every dataset named in args.data_train.
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data_new.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        # One test loader (batch size 1) per dataset named in args.data_test.
        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data_new.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data_new.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            )
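
The Data wrapper above is driven entirely by the parsed command-line options. Below is a minimal usage sketch, assuming an argparse-style namespace; the field values are placeholders, and the full option set lives in the project's option parser and in srdata.SRData, neither of which is shown in this commit excerpt.

from types import SimpleNamespace
from data_new import Data

# Placeholder options: only the fields read directly by Data.__init__ are listed.
# SRData/DIV2K read further fields from the same namespace
# (dir_data, data_range, scale, patch_size, n_colors, rgb_range, ext, ...).
args = SimpleNamespace(
    test_only=False,        # also build the training loader
    data_train=['DIV2K'],   # -> import_module('data_new.div2k'), class DIV2K
    data_test=['Set5'],     # benchmark names route to data_new.benchmark.Benchmark
    batch_size=16,
    cpu=False,              # pin_memory is enabled when not running on CPU
    n_threads=4,
)

loaders = Data(args)
# loaders.loader_train: DataLoader over a MyConcatDataset of all data_train sets
# loaders.loader_test:  one batch_size=1 DataLoader per entry in data_test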

data_new/benchmark.py (+25 lines)

import os

from data_new import common
from data_new import srdata

import numpy as np

import torch
import torch.utils.data as data

class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True, benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        if self.input_large:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
        else:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('', '.png')

data_new/common.py (+75 lines)

import random

import numpy as np
import skimage.color as sc
import pdb
import torch

# import IPython

# Randomly crop an LR patch from args[0] and the spatially aligned patches
# from the remaining (HR) images.
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
    # IPython.embed()
    # pdb.set_trace()
    ih, iw = args[0].shape[:2]
    if not input_large:
        p = scale if multi else 1
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret

# Force every image to the requested channel count (Y channel for 1, replicated
# grayscale for 3).
def set_channel(*args, n_channels=3):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]

# HWC ndarray -> CHW float tensor, rescaled to [0, rgb_range].
def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]

# Apply the same random horizontal/vertical flip and 90-degree rotation to all inputs.
def augment(*args, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]

data_new/demo.py (+39 lines)

import os

from data_new import common

import numpy as np
import imageio

import torch
import torch.utils.data as data

class Demo(data.Dataset):
    def __init__(self, args, name='Demo', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = benchmark

        self.filelist = []
        for f in os.listdir(args.dir_demo):
            if f.find('.png') >= 0 or f.find('.jp') >= 0:
                self.filelist.append(os.path.join(args.dir_demo, f))
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
        lr = imageio.imread(self.filelist[idx])
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

        return lr_t, -1, filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale

data_new/div2k.py (+32 lines)

import os
from data_new import srdata

class DIV2K(srdata.SRData):
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]

        self.begin, self.end = list(map(lambda x: int(x), data_range))
        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        if self.input_large: self.dir_lr += 'L'

data_new/div2kjpeg.py (+20 lines)

import os
from data_new import srdata
from data_new import div2k

class DIV2KJPEG(div2k.DIV2K):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.q_factor = int(name.replace('DIV2K-Q', ''))
        super(DIV2KJPEG, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'DIV2K')
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(
            self.apath, 'DIV2K_Q{}'.format(self.q_factor)
        )
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.jpg')

data_new/sr291.py (+6 lines)

from data_new import srdata

class SR291(srdata.SRData):
    def __init__(self, args, name='SR291', train=True, benchmark=False):
        super(SR291, self).__init__(args, name=name)
