from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset

# This is a simple wrapper class around ConcatDataset: it exposes the
# shared train flag and forwards scale changes to every sub-dataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super().__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'):
                d.set_scale(idx_scale)

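# Data prepares the loaders: one training loader over the concatenated
# training sets (skipped when args.test_only is set), and one test
# loader per entry in args.data_test.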
class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
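            # Import each training dataset module by name;
            # 'DIV2K-Q*' names map to the DIV2KJPEG module.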
            for d in args.data_train:
                module_name = d if 'DIV2K-Q' not in d else 'DIV2KJPEG'
                m = import_module('data_new.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

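            # Concatenate all training sets so a single loader can
            # sample batches across them.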
            self.loader_train = DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

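        # Build one evaluation loader per requested test set.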
        self.loader_test = []
        for d in args.data_test:
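            # The standard benchmark sets share the Benchmark module;
            # anything else is imported by its own module name.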
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data_new.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if 'DIV2K-Q' not in d else 'DIV2KJPEG'
                m = import_module('data_new.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

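            # Evaluate one image at a time; test images may differ in
            # size, so they are not batched together.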
            self.loader_test.append(
                DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            )
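
# Minimal usage sketch (hypothetical values; only fields referenced
# above are shown, and the dataset classes may read further fields
# such as data paths from args):
#
#   import argparse
#   args = argparse.Namespace(
#       test_only=False, data_train=['DIV2K'], data_test=['Set5'],
#       batch_size=16, cpu=False, n_threads=4,
#   )
#   loader = Data(args)
#   for batch in loader.loader_train:
#       ...  # batch structure depends on each dataset's __getitem__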