This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit d2be374

Author: saidbleik
Commit message: minor edits
1 parent 37f804b

4 files changed, with 26 additions (+) and 41 deletions (-)


utils_nlp/dataset/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -9,6 +9,6 @@


 class Split(str, Enum):
-    TRAIN : str = "train"
-    DEV : str = "dev"
-    TEST : str = "test"
+    TRAIN: str = "train"
+    DEV: str = "dev"
+    TEST: str = "test"
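The change here is cosmetic (PEP 8 spacing around the annotation colons). For context, a minimal sketch of how this str-backed Enum behaves; the assertions and the SNLI file-name pattern are illustrative, not from the diff:

    from enum import Enum

    class Split(str, Enum):
        TRAIN: str = "train"
        DEV: str = "dev"
        TEST: str = "test"

    # Because Split subclasses str, its members compare equal to plain strings
    # and can be used directly in string operations such as building file names.
    assert Split.TRAIN == "train"
    assert "snli_1.0_" + Split.DEV + ".txt" == "snli_1.0_dev.txt"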

utils_nlp/dataset/snli.py

Lines changed: 14 additions & 24 deletions
@@ -20,17 +20,15 @@
 LABEL_COL = "score"


-def load_pandas_df(
-    local_cache_path=None, file_split=Split.TRAIN, file_type="txt", nrows=None
-):
+def load_pandas_df(local_cache_path=None, file_split=Split.TRAIN, file_type="txt", nrows=None):
     """
     Loads the SNLI dataset as pd.DataFrame
     Download the dataset from "https://nlp.stanford.edu/projects/snli/snli_1.0.zip", unzip, and load

     Args:
         local_cache_path (str): Path (directory or a zip file) to cache the downloaded zip file.
-                If None, all the intermediate files will be stored in a temporary directory and removed
-                after use.
+            If None, all the intermediate files will be stored in a temporary directory and removed
+            after use.
         file_split (str): File split to load, defaults to "train"
         file_type (str): File type to load, defaults to "txt"
         nrows (int): Number of rows to load, defaults to None (in which all rows will be returned)

@@ -78,12 +76,8 @@ def _maybe_download_and_extract(zip_path, file_split, file_type):
     extract_path = os.path.join(dir_path, file_name)

     if not os.path.exists(extract_path):
-        dpath = download_snli(zip_path)
-        extract_snli(
-            zip_path,
-            source_path=SNLI_DIRNAME + "/" + file_name,
-            dest_path=extract_path,
-        )
+        _ = download_snli(zip_path)
+        extract_snli(zip_path, source_path=SNLI_DIRNAME + "/" + file_name, dest_path=extract_path)

     return extract_path

@@ -143,24 +137,20 @@ def clean_cols(df):
     )

     snli_df = snli_df.rename(
-        columns={
-            "sentence1": S1_COL,
-            "sentence2": S2_COL,
-            "gold_label": LABEL_COL,
-        }
+        columns={"sentence1": S1_COL, "sentence2": S2_COL, "gold_label": LABEL_COL}
     )

     return snli_df


 def clean_rows(df, label_col=LABEL_COL):
     """Drop badly formatted rows from the input dataframe
-
+
     Args:
         df (pd.DataFrame): Input dataframe
         label_col (str): Name of label column.
-                Defaults to the standardized column name that is set after running the clean_col method.
-
+            Defaults to the standardized column name that is set after running the clean_col method.
+
     Returns:
         pd.DataFrame
     """

@@ -169,23 +159,23 @@ def clean_rows(df, label_col=LABEL_COL):

     return snli_df

+
 def clean_df(df, label_col=LABEL_COL):
     df = clean_cols(df)
     df = clean_rows(df, label_col)

     return df

-def load_azureml_df(
-    local_cache_path=None, file_split=Split.TRAIN, file_type="txt"
-):
+
+def load_azureml_df(local_cache_path=None, file_split=Split.TRAIN, file_type="txt"):
     """
     Loads the SNLI dataset as AzureML dataflow object
     Download the dataset from "https://nlp.stanford.edu/projects/snli/snli_1.0.zip", unzip, and load.

     Args:
         local_cache_path (str): Path (directory or a zip file) to cache the downloaded zip file.
-                If None, all the intermediate files will be stored in a temporary directory and removed
-                after use.
+            If None, all the intermediate files will be stored in a temporary directory and removed
+            after use.
         file_split (str): File split to load. One of (dev, test, train)
         file_type (str): File type to load. One of (txt, jsonl)
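Taken together, the snli.py changes are formatting-only: multi-line signatures and call sites are collapsed onto single lines, docstring indentation and blank-line spacing are normalized, and the unused dpath binding becomes _. For orientation, a minimal usage sketch of the public loaders shown in this diff; the cache path and row count are illustrative:

    from utils_nlp.dataset import Split
    from utils_nlp.dataset import snli

    # First call downloads and unzips snli_1.0.zip into the cache path.
    df = snli.load_pandas_df(
        local_cache_path="./data", file_split=Split.DEV, file_type="txt", nrows=1000
    )

    # Standardize column names and drop badly formatted rows.
    df = snli.clean_df(df)
    print(df.columns)  # includes the renamed S1_COL, S2_COL, and the "score" label column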

utils_nlp/dataset/wikigold.py

Lines changed: 3 additions & 9 deletions
@@ -14,9 +14,7 @@
 )


-def load_train_test_dfs(
-    local_cache_path="./", test_percentage=0.5, random_seed=None
-):
+def load_train_test_dfs(local_cache_path="./", test_percentage=0.5, random_seed=None):
     """
     Get the training and testing data frames based on test_percentage.

@@ -58,13 +56,9 @@ def load_train_test_dfs(
     train_sentence_list = sentence_list[test_sentence_count:]
     train_labels_list = labels_list[test_sentence_count:]

-    train_df = pd.DataFrame(
-        {"sentence": train_sentence_list, "labels": train_labels_list}
-    )
+    train_df = pd.DataFrame({"sentence": train_sentence_list, "labels": train_labels_list})

-    test_df = pd.DataFrame(
-        {"sentence": test_sentence_list, "labels": test_labels_list}
-    )
+    test_df = pd.DataFrame({"sentence": test_sentence_list, "labels": test_labels_list})

     return (train_df, test_df)
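As above, only line wrapping changes here. A short usage sketch of the refactored function, with an assumed cache directory and seed for illustration:

    from utils_nlp.dataset.wikigold import load_train_test_dfs

    # Split wikigold 50/50 into train and test; fix the seed for a reproducible split.
    train_df, test_df = load_train_test_dfs(
        local_cache_path="./data", test_percentage=0.5, random_seed=42
    )
    print(list(train_df.columns))  # ["sentence", "labels"], per the DataFrame construction above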

utils_nlp/dataset/xnli_torch_dataset.py

Lines changed: 6 additions & 5 deletions
@@ -60,15 +60,16 @@ def __init__(
         Load the dataset here
         Args:
             file_split (str, optional):The subset to load.
-                One of: {"train", "dev", "test"}
-                Defaults to "train".
+                One of: {"train", "dev", "test"}
+                Defaults to "train".
             cache_dir (str, optional):Path to store the data.
-                Defaults to "./".
+                Defaults to "./".
             language(str):Language required to load which xnli file (eg - "en", "zh")
             to_lowercase(bool):flag to convert samples in dataset to lowercase
             tok_language(Language, optional): language (Language, optional): The pretrained model's language.
-                Defaults to Language.ENGLISH.
-            data_percent_used(float, optional): Data used to create Torch Dataset.Defaults to "1.0" which is 100% data
+                Defaults to Language.ENGLISH.
+            data_percent_used(float, optional): Data used to create Torch Dataset.
+                Defaults to "1.0" which is 100% data
         """
         if file_split not in VALID_FILE_SPLIT:
             raise ValueError("The file split is not part of ", VALID_FILE_SPLIT)
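This hunk only rewraps the constructor's docstring. For orientation, a hypothetical construction call matching the documented arguments; the class name XnliDataset is an assumption, since only __init__'s docstring appears in this diff:

    # "XnliDataset" is an assumed class name -- not shown in this hunk.
    from utils_nlp.dataset.xnli_torch_dataset import XnliDataset

    dataset = XnliDataset(
        file_split="dev",       # one of {"train", "dev", "test"}
        cache_dir="./",
        language="en",          # which xnli file to load
        to_lowercase=True,
        data_percent_used=0.1,  # use 10% of the data; tok_language keeps its default
    )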
