Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 2091c38

Browse files
author
Emmanuel Awa
committed
Feat: add functionality to download MNLI preprocessed tsv data.
Leverage NYU Jiant Toolkit preprocessed tsv data source
1 parent 150909f commit 2091c38

File tree

1 file changed

+62
-10
lines changed

1 file changed

+62
-10
lines changed

utils_nlp/dataset/multinli.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,19 @@
2222
from utils_nlp.models.transformers.sequence_classification import Processor
2323

2424
URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
25+
26+
# Source - https://github.com/nyu-mll/jiant/blob/master/scripts/download_glue_data.py
27+
URL_JIANT_MNLI_TSV = "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce"
2528
DATA_FILES = {
2629
"train": "multinli_1.0/multinli_1.0_train.jsonl",
2730
"dev_matched": "multinli_1.0/multinli_1.0_dev_matched.jsonl",
2831
"dev_mismatched": "multinli_1.0/multinli_1.0_dev_mismatched.jsonl",
2932
}
3033

3134

32-
def download_file_and_extract(local_cache_path: str = ".", file_split: str = "train") -> None:
35+
def download_file_and_extract(
36+
local_cache_path: str = ".", file_split: str = "train"
37+
) -> None:
3338
"""Download and extract the dataset files
3439
3540
Args:
@@ -46,6 +51,31 @@ def download_file_and_extract(local_cache_path: str = ".", file_split: str = "tr
4651
extract_zip(os.path.join(local_cache_path, file_name), local_cache_path)
4752

4853

54+
def download_tsv_files_and_extract(local_cache_path: str = ".") -> None:
    """Download and extract the MNLI dataset in TSV format.

    Fetches the preprocessed TSV archive from the NYU jiant toolkit data
    mirror (``URL_JIANT_MNLI_TSV``), extracts it into an ``MNLI``
    sub-directory of ``local_cache_path`` (extraction is skipped when that
    directory already exists), and removes the downloaded zip afterwards.

    Args:
        local_cache_path (str, optional): Directory to cache files to.
            Defaults to the current working directory (".").

    Returns:
        None

    Raises:
        IOError: If the download or extraction fails.
    """
    folder_name = "MNLI"
    file_name = f"{folder_name}.zip"
    # Hoist the path joins: both are reused below.
    zip_path = os.path.join(local_cache_path, file_name)
    folder_path = os.path.join(local_cache_path, folder_name)
    try:
        maybe_download(URL_JIANT_MNLI_TSV, file_name, local_cache_path)
        # Skip extraction when a previous run already produced the folder.
        if not os.path.exists(folder_path):
            extract_zip(zip_path, local_cache_path)

        # Clean up the zip download; only the extracted folder is needed.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    except IOError:
        # Bare re-raise preserves the original traceback (unlike `raise (e)`).
        raise
    print("Downloaded file to: ", folder_path)
77+
78+
4979
def load_pandas_df(local_cache_path=".", file_split="train"):
5080
"""Loads extracted dataset into pandas
5181
Args:
@@ -61,10 +91,18 @@ def load_pandas_df(local_cache_path=".", file_split="train"):
6191
download_file_and_extract(local_cache_path, file_split)
6292
except Exception as e:
6393
raise e
64-
return pd.read_json(os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True)
94+
return pd.read_json(
95+
os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True
96+
)
6597

6698

67-
def get_generator(local_cache_path=".", file_split="train", block_size=10e6, batch_size=10e6, num_batches=None):
99+
def get_generator(
100+
local_cache_path=".",
101+
file_split="train",
102+
block_size=10e6,
103+
batch_size=10e6,
104+
num_batches=None,
105+
):
68106
""" Returns an extracted dataset as a random batch generator that
69107
yields pandas dataframes.
70108
Args:
@@ -84,9 +122,13 @@ def get_generator(local_cache_path=".", file_split="train", block_size=10e6, bat
84122
except Exception as e:
85123
raise e
86124

87-
loader = DaskJSONLoader(os.path.join(local_cache_path, DATA_FILES[file_split]), block_size=block_size)
125+
loader = DaskJSONLoader(
126+
os.path.join(local_cache_path, DATA_FILES[file_split]), block_size=block_size
127+
)
88128

89-
return loader.get_sequential_batches(batch_size=int(batch_size), num_batches=num_batches)
129+
return loader.get_sequential_batches(
130+
batch_size=int(batch_size), num_batches=num_batches
131+
)
90132

91133

92134
def load_tc_dataset(
@@ -161,17 +203,23 @@ def load_tc_dataset(
161203
label_encoder.fit(all_df[label_col])
162204

163205
if test_fraction < 0 or test_fraction >= 1.0:
164-
logging.warning("Invalid test fraction value: {}, changed to 0.25".format(test_fraction))
206+
logging.warning(
207+
"Invalid test fraction value: {}, changed to 0.25".format(test_fraction)
208+
)
165209
test_fraction = 0.25
166210

167-
train_df, test_df = train_test_split(all_df, train_size=(1.0 - test_fraction), random_state=random_seed)
211+
train_df, test_df = train_test_split(
212+
all_df, train_size=(1.0 - test_fraction), random_state=random_seed
213+
)
168214

169215
if train_sample_ratio > 1.0:
170216
train_sample_ratio = 1.0
171217
logging.warning("Setting the training sample ratio to 1.0")
172218
elif train_sample_ratio < 0:
173219
logging.error("Invalid training sample ration: {}".format(train_sample_ratio))
174-
raise ValueError("Invalid training sample ration: {}".format(train_sample_ratio))
220+
raise ValueError(
221+
"Invalid training sample ration: {}".format(train_sample_ratio)
222+
)
175223

176224
if test_sample_ratio > 1.0:
177225
test_sample_ratio = 1.0
@@ -195,12 +243,16 @@ def load_tc_dataset(
195243
train_dataset = processor.dataset_from_dataframe(
196244
df=train_df, text_col=text_col, label_col=label_col, max_len=max_len,
197245
)
198-
train_dataloader = dataloader_from_dataset(train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True)
246+
train_dataloader = dataloader_from_dataset(
247+
train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True
248+
)
199249

200250
test_dataset = processor.dataset_from_dataframe(
201251
df=test_df, text_col=text_col, label_col=label_col, max_len=max_len,
202252
)
203-
test_dataloader = dataloader_from_dataset(test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False)
253+
test_dataloader = dataloader_from_dataset(
254+
test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False
255+
)
204256

205257
return (train_dataloader, test_dataloader, label_encoder, test_labels)
206258

0 commit comments

Comments
 (0)