from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomCrop, RandomHorizontalFlip

+from video_transformers.pytorchvideo_wrapper.data.labeled_video_dataset import labeled_video_dataset
from video_transformers.pytorchvideo_wrapper.data.labeled_video_paths import LabeledVideoDataset, LabeledVideoPaths
from video_transformers.utils.extra import class_to_config

@@ -53,8 +54,8 @@ def __init__(
            input_size: model input size
            means: mean of the video clip
            stds: standard deviation of the video clip
-           min_short_side_scale: minimum short side of the video clip
-           max_short_side_scale: maximum short side of the video clip
+           min_short_side: minimum short side of the video clip
+           max_short_side: maximum short side of the video clip
            horizontal_flip_p: probability of horizontal flip
            clip_duration: duration of each video clip
@@ -77,10 +78,13 @@ def __init__(
        self.clip_duration = clip_duration

        # Transforms applied to train dataset.
+       def normalize_func(x):
+           return x / 255.0
+
        self.train_video_transform = Compose(
            [
                UniformTemporalSubsample(self.num_timesteps),
-               Lambda(lambda x: x / 255.0),
+               Lambda(normalize_func),
                Normalize(self.means, self.stds),
                RandomShortSideScale(
                    min_size=self.min_short_side,
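
Note on the rescaling step this hunk touches: the swap from the inline lambda to the named `normalize_func` keeps the arithmetic identical. Decoded frames arrive as 0-255 intensities, `normalize_func` maps them to [0, 1], and `Normalize` then standardizes each channel. A minimal sketch of the equivalent arithmetic; the mean/std values below are illustrative assumptions, not the repo's defaults:

```python
import torch

def normalize_func(x):
    return x / 255.0

means = torch.tensor([0.45, 0.45, 0.45])    # assumed example values
stds = torch.tensor([0.225, 0.225, 0.225])  # assumed example values

clip = torch.randint(0, 256, (3, 8, 224, 224)).float()  # (C, T, H, W)
x = normalize_func(clip)                                # [0, 255] -> [0, 1]
x = (x - means[:, None, None, None]) / stds[:, None, None, None]
print(x.mean().item(), x.std().item())                  # roughly zero-centered
```
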
@@ -97,7 +101,7 @@ def __init__(
        self.val_video_transform = Compose(
            [
                UniformTemporalSubsample(self.num_timesteps),
-               Lambda(lambda x: x / 255.0),
+               Lambda(normalize_func),
                Normalize(self.means, self.stds),
                ShortSideScale(self.min_short_side),
                CenterCrop(self.input_size),
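
The validation pipeline differs from the train one only in determinism: fixed `ShortSideScale` and `CenterCrop` replace the random scale/crop/flip. A runnable sketch with assumed sizes (8 timesteps, short side 256, 224x224 input), showing the (C, T, H, W) shape flow that pytorchvideo transforms expect:

```python
import torch
from torchvision.transforms import CenterCrop, Compose, Lambda
from pytorchvideo.transforms import Normalize, ShortSideScale, UniformTemporalSubsample

def normalize_func(x):
    return x / 255.0

val_transform = Compose(
    [
        UniformTemporalSubsample(8),  # (C, T, H, W) -> 8 evenly spaced frames
        Lambda(normalize_func),
        Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
        ShortSideScale(256),          # deterministic, unlike RandomShortSideScale
        CenterCrop(224),              # deterministic, unlike RandomCrop
    ]
)

clip = torch.randint(0, 256, (3, 32, 240, 320)).float()
print(val_transform(clip).shape)      # torch.Size([3, 8, 224, 224])
```
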
@@ -112,7 +116,6 @@ def __init__(
        train_root: str,
        val_root: str,
        test_root: str = None,
-       train_dataset_multiplier: int = 1,
        batch_size: int = 4,
        num_workers: int = 4,
        num_timesteps: int = 8,
@@ -158,8 +161,6 @@ def __init__(
            Path to kinetics formatted train folder.
        clip_duration: float
            Duration of sampled clip for each video.
-       train_dataset_multiplier: int
-           Multipler for number of of random training data samples.
        batch_size: int
            Batch size for training and validation.
        num_workers: int
@@ -196,7 +197,6 @@ def __init__(
        self.train_root = train_root
        self.val_root = val_root
        self.test_root = test_root if test_root is not None else val_root
-       self.train_dataset_multiplier = train_dataset_multiplier
        self.labels = None

        self.train_dataloader = self._get_train_dataloader()
@@ -212,18 +212,13 @@ def config(self) -> Dict:
        return class_to_config(self, ignored_attrs=("config", "train_root", "val_root", "test_root"))

    def _get_train_dataloader(self):
-       labeled_video_paths = LabeledVideoPaths.from_path(self.train_root)
-       labeled_video_paths.path_prefix = ""
-       video_sampler = torch.utils.data.RandomSampler
        clip_sampler = pytorchvideo.data.make_clip_sampler("random", self.preprocessor_config["clip_duration"])
-       dataset = LabeledVideoDataset(
-           labeled_video_paths,
-           clip_sampler,
-           video_sampler,
-           self.preprocessor.train_transform,
+       dataset = labeled_video_dataset(
+           data_path=self.train_root,
+           clip_sampler=clip_sampler,
+           transform=self.preprocessor.train_transform,
            decode_audio=False,
            decoder="pyav",
-           dataset_multiplier=self.train_dataset_multiplier,
        )
        self.labels = dataset.labels
        return DataLoader(
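
For reference, the removed three-step construction and the new single call should be equivalent: upstream pytorchvideo ships a `labeled_video_dataset` factory that performs exactly this folding, and the wrapper imported above appears to mirror it. A sketch of the factory's job, using the upstream classes; treat the exact wrapper signature as an assumption:

```python
import torch
from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths

def labeled_video_dataset_sketch(data_path, clip_sampler, transform=None,
                                 decode_audio=True, decoder="pyav"):
    # Folds the removed lines into one call: build the (path, label) list,
    # clear the path prefix, and hand everything to LabeledVideoDataset
    # with the default random video sampler.
    labeled_video_paths = LabeledVideoPaths.from_path(data_path)
    labeled_video_paths.path_prefix = ""
    return LabeledVideoDataset(
        labeled_video_paths,
        clip_sampler,
        torch.utils.data.RandomSampler,
        transform,
        decode_audio=decode_audio,
        decoder=decoder,
    )
```
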
@@ -234,18 +229,14 @@ def _get_train_dataloader(self):
        )

    def _get_val_dataloader(self):
-       labeled_video_paths = LabeledVideoPaths.from_path(self.val_root)
-       labeled_video_paths.path_prefix = ""
-       video_sampler = torch.utils.data.SequentialSampler
        clip_sampler = pytorchvideo.data.clip_sampling.UniformClipSamplerTruncateFromStart(
            clip_duration=self.preprocessor_config["clip_duration"],
            truncation_duration=self.preprocessor_config["clip_duration"],
        )
-       dataset = LabeledVideoDataset(
-           labeled_video_paths,
-           clip_sampler,
-           video_sampler,
-           self.preprocessor.val_transform,
+       dataset = labeled_video_dataset(
+           data_path=self.val_root,
+           clip_sampler=clip_sampler,
+           transform=self.preprocessor.val_transform,
            decode_audio=False,
            decoder="pyav",
        )
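
Training and evaluation use different clip samplers: a randomly positioned `clip_duration`-second window per video for training, and a truncate-from-start uniform sampler for val/test, which (as its `clip_duration`/`truncation_duration` pair suggests) reads each video deterministically from its beginning; `UniformClipSamplerTruncateFromStart` itself is this repo's wrapper addition, not stock pytorchvideo. The stock factory covers the common cases:

```python
import pytorchvideo.data

clip_duration = 2.0  # assumed value; the repo reads it from preprocessor_config

# Training: one randomly positioned clip per pass over a video.
train_sampler = pytorchvideo.data.make_clip_sampler("random", clip_duration)

# Stock deterministic alternative: tile each video with back-to-back clips.
eval_sampler = pytorchvideo.data.make_clip_sampler("uniform", clip_duration)
```
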
@@ -257,18 +248,14 @@ def _get_val_dataloader(self):
        )

    def _get_test_dataloader(self):
-       labeled_video_paths = LabeledVideoPaths.from_path(self.test_root)
-       labeled_video_paths.path_prefix = ""
-       video_sampler = torch.utils.data.SequentialSampler
        clip_sampler = pytorchvideo.data.clip_sampling.UniformClipSamplerTruncateFromStart(
            clip_duration=self.preprocessor_config["clip_duration"],
            truncation_duration=self.preprocessor_config["clip_duration"],
        )
-       dataset = LabeledVideoDataset(
-           labeled_video_paths,
-           clip_sampler,
-           video_sampler,
-           self.preprocessor.val_transform,
+       dataset = labeled_video_dataset(
+           data_path=self.test_root,
+           clip_sampler=clip_sampler,
+           transform=self.preprocessor.val_transform,
            decode_audio=False,
            decoder="pyav",
        )
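
After this change the three dataloaders yield the same structure as before: pytorchvideo's `LabeledVideoDataset` produces dict samples, so collated batches carry the transformed clip under `"video"` and the folder-derived class id under `"label"`. A consumption sketch; the model, optimizer, and loss names are generic placeholders, not this repo's API:

```python
def run_one_epoch(dataloader, model, optimizer, loss_fn):
    for batch in dataloader:
        videos = batch["video"]   # (B, C, T, H, W) after the transforms above
        labels = batch["label"]   # (B,) integer class ids
        optimizer.zero_grad()
        loss = loss_fn(model(videos), labels)
        loss.backward()
        optimizer.step()
```
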