 from utils_nlp.models.transformers.sequence_classification import Processor

 URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
+
+# Source - https://github.com/nyu-mll/jiant/blob/master/scripts/download_glue_data.py
+URL_JIANT_MNLI_TSV = "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce"

 DATA_FILES = {
     "train": "multinli_1.0/multinli_1.0_train.jsonl",
     "dev_matched": "multinli_1.0/multinli_1.0_dev_matched.jsonl",
     "dev_mismatched": "multinli_1.0/multinli_1.0_dev_mismatched.jsonl",
 }


-def download_file_and_extract(local_cache_path: str = ".", file_split: str = "train") -> None:
+def download_file_and_extract(
+    local_cache_path: str = ".", file_split: str = "train"
+) -> None:
     """Download and extract the dataset files

     Args:
@@ -46,6 +51,31 @@ def download_file_and_extract(local_cache_path: str = ".", file_split: str = "tr
         extract_zip(os.path.join(local_cache_path, file_name), local_cache_path)


+def download_tsv_files_and_extract(local_cache_path: str = ".") -> None:
+    """Download and extract the dataset files in tsv format from NYU Jiant;
+    downloads both the original and the tsv-formatted data.
+
+    Args:
+        local_cache_path (str, optional): Directory to cache files to.
+            Defaults to the current working directory (default: {"."}).
+
+    Returns:
+        None
+    """
+    try:
+        folder_name = "MNLI"
+        file_name = f"{folder_name}.zip"
+        maybe_download(URL_JIANT_MNLI_TSV, file_name, local_cache_path)
+        if not os.path.exists(os.path.join(local_cache_path, folder_name)):
+            extract_zip(os.path.join(local_cache_path, file_name), local_cache_path)
+
+        # Clean up zip download
+        if os.path.exists(os.path.join(local_cache_path, file_name)):
+            os.remove(os.path.join(local_cache_path, file_name))
+    except IOError as e:
+        raise e
+    print("Downloaded file to: ", os.path.join(local_cache_path, folder_name))
+
+
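For orientation, a minimal usage sketch of the helper added above (an editor's illustration, not part of the diff; the module path utils_nlp.dataset.mnli and the cache directory are assumptions):

    from utils_nlp.dataset.mnli import download_tsv_files_and_extract

    # Fetches MNLI.zip from the jiant mirror, extracts it into ./temp/MNLI,
    # then deletes the zip. Extraction is skipped when ./temp/MNLI already exists.
    download_tsv_files_and_extract(local_cache_path="./temp")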
 def load_pandas_df(local_cache_path=".", file_split="train"):
     """Loads extracted dataset into pandas
     Args:
@@ -61,10 +91,18 @@ def load_pandas_df(local_cache_path=".", file_split="train"):
         download_file_and_extract(local_cache_path, file_split)
     except Exception as e:
         raise e
-    return pd.read_json(os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True)
+    return pd.read_json(
+        os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True
+    )


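A hedged sketch of calling load_pandas_df (the column names follow the published MultiNLI jsonl schema; the cache path is illustrative):

    # Ensures the split is downloaded and extracted, then reads the jsonl
    # file line-by-line into a dataframe.
    df = load_pandas_df(local_cache_path="./temp", file_split="dev_matched")
    print(df[["sentence1", "sentence2", "gold_label"]].head())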
-def get_generator(local_cache_path=".", file_split="train", block_size=10e6, batch_size=10e6, num_batches=None):
+def get_generator(
+    local_cache_path=".",
+    file_split="train",
+    block_size=10e6,
+    batch_size=10e6,
+    num_batches=None,
+):
     """ Returns an extracted dataset as a random batch generator that
     yields pandas dataframes.
     Args:
@@ -84,9 +122,13 @@ def get_generator(local_cache_path=".", file_split="train", block_size=10e6, bat
     except Exception as e:
         raise e

-    loader = DaskJSONLoader(os.path.join(local_cache_path, DATA_FILES[file_split]), block_size=block_size)
+    loader = DaskJSONLoader(
+        os.path.join(local_cache_path, DATA_FILES[file_split]), block_size=block_size
+    )

-    return loader.get_sequential_batches(batch_size=int(batch_size), num_batches=num_batches)
+    return loader.get_sequential_batches(
+        batch_size=int(batch_size), num_batches=num_batches
+    )


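Since the train split is large, get_generator streams it through DaskJSONLoader rather than loading it whole; a usage sketch under the same assumptions as above (default block and batch sizes from the signature):

    # Yields pandas dataframes batch by batch; num_batches caps the iteration.
    for batch_df in get_generator(
        local_cache_path="./temp", file_split="train", num_batches=2
    ):
        print(batch_df.shape)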
 def load_tc_dataset(
@@ -161,17 +203,23 @@ def load_tc_dataset(
     label_encoder.fit(all_df[label_col])

     if test_fraction < 0 or test_fraction >= 1.0:
-        logging.warning("Invalid test fraction value: {}, changed to 0.25".format(test_fraction))
+        logging.warning(
+            "Invalid test fraction value: {}, changed to 0.25".format(test_fraction)
+        )
         test_fraction = 0.25

-    train_df, test_df = train_test_split(all_df, train_size=(1.0 - test_fraction), random_state=random_seed)
+    train_df, test_df = train_test_split(
+        all_df, train_size=(1.0 - test_fraction), random_state=random_seed
+    )

     if train_sample_ratio > 1.0:
         train_sample_ratio = 1.0
         logging.warning("Setting the training sample ratio to 1.0")
     elif train_sample_ratio < 0:
         logging.error("Invalid training sample ratio: {}".format(train_sample_ratio))
-        raise ValueError("Invalid training sample ratio: {}".format(train_sample_ratio))
+        raise ValueError(
+            "Invalid training sample ratio: {}".format(train_sample_ratio)
+        )

     if test_sample_ratio > 1.0:
         test_sample_ratio = 1.0
@@ -195,12 +243,16 @@ def load_tc_dataset(
     train_dataset = processor.dataset_from_dataframe(
         df=train_df, text_col=text_col, label_col=label_col, max_len=max_len,
     )
-    train_dataloader = dataloader_from_dataset(train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True)
+    train_dataloader = dataloader_from_dataset(
+        train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True
+    )

     test_dataset = processor.dataset_from_dataframe(
         df=test_df, text_col=text_col, label_col=label_col, max_len=max_len,
     )
-    test_dataloader = dataloader_from_dataset(test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False)
+    test_dataloader = dataloader_from_dataset(
+        test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False
+    )

     return (train_dataloader, test_dataloader, label_encoder, test_labels)
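Finally, an end-to-end sketch of load_tc_dataset as reformatted here (an editor's illustration: only test_fraction, batch_size, num_gpus, and random_seed are visible in these hunks, so the remaining arguments and their defaults are assumptions; label_encoder behaves like a fitted sklearn LabelEncoder, since .fit() is called above):

    train_dataloader, test_dataloader, label_encoder, test_labels = load_tc_dataset(
        test_fraction=0.25,  # invalid values fall back to 0.25 with a warning
        batch_size=32,
        num_gpus=1,
        random_seed=42,
    )
    # After inference, map predicted label ids back to label strings with:
    # label_encoder.inverse_transform(predicted_ids)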