diff --git a/src/train.py b/src/train.py index c7628b3..87a6af7 100644 --- a/src/train.py +++ b/src/train.py @@ -45,8 +45,8 @@ assert test_body_vecs.shape[0] == test_title_vecs.shape[0] == test_labels.shape[0] # build model architecture -body_emb_size = 50 -title_emb_size = 50 +body_emb_size = 85 +title_emb_size = 85 batch_size = 900 epochs = 4 @@ -97,4 +97,4 @@ # save artifacts wandb.save(os.path.join(out_dir, '*')) wandb.save('/data/metadata.json') -wandb.save(os.path.join(input_dir, '*.dpkl')) \ No newline at end of file +wandb.save(os.path.join(input_dir, '*.dpkl'))