@@ -14,15 +14,22 @@ def train(args):
14
14
dataset = SeriesDataset (args .data_dir )
15
15
16
16
train_size = int (0.8 * len (dataset ))
17
-
18
17
train_dataset , val_dataset = torch .utils .data .random_split (dataset , [train_size , len (dataset ) - train_size ])
19
18
20
19
train_dataloader = DataLoader (train_dataset , batch_size = args .batch_size , shuffle = True )
21
20
val_dataloader = DataLoader (val_dataset , batch_size = args .batch_size , shuffle = False )
22
21
23
- trainer = lightning .Trainer (default_root_dir = args .checkpoint_dir , max_epochs = 100 )
22
+ trainer = lightning .Trainer (default_root_dir = args .checkpoint_dir , max_epochs = args . epochs )
24
23
trainer .fit (model = model , train_dataloaders = train_dataloader , val_dataloaders = val_dataloader )
25
24
25
+ checkpoint = {
26
+ 'hyperparameters' : vars (args ),
27
+ 'state_dict' : model .state_dict (),
28
+ }
29
+ torch .save (checkpoint , f'{ args .checkpoint_dir } /model_checkpoint.pth' )
30
+
31
+ trainer .test (model = model , dataloaders = val_dataloader )
32
+
26
33
27
34
def main ():
28
35
parser = argparse .ArgumentParser (description = 'CNN-Transformer Time Series Classification' )
@@ -32,12 +39,13 @@ def main():
32
39
parser .add_argument ('--num_layers' , type = int , default = 2 , help = 'Number of transformer layers' )
33
40
parser .add_argument ('--num_classes' , type = int , default = 5 , help = 'Number of classes for classification' )
34
41
parser .add_argument ('--batch_size' , type = int , default = 8 , help = 'Batch size for training' )
35
-
36
42
parser .add_argument ('--data_dir' , type = str , default = "Sara_dataset/" , help = 'Dataset directory for loading series' )
37
43
parser .add_argument ('--checkpoint_dir' , type = str , default = "checkpoints/" , help = 'Directory to save checkpoints' )
38
44
parser .add_argument ('--resume_training' , action = 'store_true' , help = 'Resume training from checkpoint' )
45
+ parser .add_argument ('--epochs' , type = int , default = 100 , help = 'Number of epochs to train for' )
39
46
40
- train (parser .parse_args ())
47
+ args = parser .parse_args ()
48
+ train (args )
41
49
42
50
43
51
if __name__ == '__main__' :
0 commit comments