
Commit 76e1da0

UT for src/guidellm/data/deserializers/file.py (#495)
## Summary

Adds unit tests for `src/guidellm/data/deserializers/file.py`.

## Details

- [ ] pytest tests/unit/data/deserializers/test_file.py

## Related Issues

- Resolves #

---

- [ ] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [ ] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)
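A minimal illustration of the docstring convention referenced in the last checklist item (the test name here is hypothetical, not part of this PR):

```python
def test_example_deserializer():
    """## WRITTEN BY AI ##

    Example marker docstring identifying an AI-generated test.
    """
```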
2 parents 0adaff9 + 97ad769 commit 76e1da0

File tree

tests/unit/data/deserializers/test_file.py

1 file changed: 368 additions, 0 deletions
@@ -0,0 +1,368 @@
import csv
import io
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from datasets import Dataset, DatasetDict
from pyarrow import ipc

from guidellm.data.deserializers.deserializer import DataNotSupportedError
from guidellm.data.deserializers.file import (
    ArrowFileDatasetDeserializer,
    CSVFileDatasetDeserializer,
    DBFileDatasetDeserializer,
    HDF5FileDatasetDeserializer,
    JSONFileDatasetDeserializer,
    ParquetFileDatasetDeserializer,
    TarFileDatasetDeserializer,
    TextFileDatasetDeserializer,
)

def processor_factory():
    # Placeholder factory; these tests do not exercise tokenization.
    return None

###################
# Tests text file deserializer
###################


@pytest.mark.sanity
def test_text_file_deserializer_success(tmp_path):
    # Arrange: create a temp text file
    file_path = tmp_path / "sample.txt"
    file_content = ["hello\n", "world\n"]
    file_path.write_text("".join(file_content))

    deserializer = TextFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=123,
    )

    # Assert
    assert isinstance(dataset, Dataset)
    assert dataset["text"] == file_content
    assert len(dataset) == 2

@pytest.mark.parametrize(
    "invalid_data",
    [
        123,  # Not a path
        None,  # Not a path
        {"file": "abc.txt"},  # Wrong type
    ],
)
@pytest.mark.sanity
def test_text_file_deserializer_invalid_type(invalid_data):
    deserializer = TextFileDatasetDeserializer()

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=invalid_data,
            processor_factory=processor_factory(),
            random_seed=0,
        )

@pytest.mark.sanity
def test_text_file_deserializer_file_not_exists(tmp_path):
    deserializer = TextFileDatasetDeserializer()
    non_existent_file = tmp_path / "missing.txt"

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=non_existent_file,
            processor_factory=processor_factory(),
            random_seed=0,
        )

@pytest.mark.sanity
def test_text_file_deserializer_not_a_file(tmp_path):
    deserializer = TextFileDatasetDeserializer()
    directory = tmp_path / "folder"
    directory.mkdir()

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=directory,
            processor_factory=processor_factory(),
            random_seed=0,
        )

@pytest.mark.sanity
def test_text_file_deserializer_invalid_file_extension(tmp_path):
    deserializer = TextFileDatasetDeserializer()

    file_path = tmp_path / "data.ttl"
    file_path.write_text("hello")

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=file_path,
            processor_factory=processor_factory(),
            random_seed=0,
        )

###################
# Tests parquet file deserializer
###################


def create_parquet_file(path: Path):
    """Helper to create a minimal parquet file."""
    table = pa.Table.from_pydict({"text": ["hello", "world"]})
    pq.write_table(table, path)

@pytest.mark.sanity
def test_parquet_file_deserializer_success(tmp_path):
    file_path = tmp_path / "sample.parquet"
    create_parquet_file(file_path)

    deserializer = ParquetFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=42,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert dataset["train"].column_names == ["text"]
    assert dataset["train"]["text"] == ["hello", "world"]
    assert len(dataset["train"]["text"]) == 2

@pytest.mark.sanity
def test_parquet_file_deserializer_file_not_exists(tmp_path):
    deserializer = ParquetFileDatasetDeserializer()
    missing_file = tmp_path / "missing.parquet"

    with pytest.raises(DataNotSupportedError):
        deserializer(
            data=missing_file,
            processor_factory=processor_factory(),
            random_seed=3,
        )

###################
# Tests csv file deserializer
###################


def create_csv_file(path: Path):
    """Helper to create a minimal csv file."""
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(["text"])
    writer.writerow(["hello world"])
    with path.open("w") as f:
        f.write(output.getvalue())

@pytest.mark.sanity
def test_csv_file_deserializer_success(tmp_path):
    # Arrange: create a temp csv file
    file_path = tmp_path / "sample.csv"
    create_csv_file(file_path)

    deserializer = CSVFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=43,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert dataset["train"]["text"] == ["hello world"]
    assert len(dataset["train"]) == 1

###################
# Tests json file deserializer
###################


@pytest.mark.sanity
def test_json_file_deserializer_success(tmp_path):
    # Arrange: create a temp json file
    file_path = tmp_path / "sample.json"
    file_content = '{"text": "hello world"}\n'
    file_path.write_text(file_content)

    deserializer = JSONFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=123,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert dataset["train"]["text"] == ["hello world"]
    assert len(dataset) == 1

###################
# Tests arrow file deserializer
###################


@pytest.mark.sanity
def test_arrow_file_deserializer_success(tmp_path):
    # Arrange: create a temp arrow file
    table = pa.Table.from_pydict({"text": ["hello", "world"]})
    file_path = tmp_path / "sample.arrow"

    with (
        pa.OSFile(str(file_path), "wb") as sink,
        ipc.RecordBatchFileWriter(sink, table.schema) as writer,
    ):
        writer.write_table(table)

    deserializer = ArrowFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=42,
    )

    # Assert
    assert isinstance(dataset, DatasetDict)
    assert "train" in dataset
    assert isinstance(dataset["train"], Dataset)
    assert dataset["train"].num_rows == 2

###################
# Tests HDF5 file deserializer
###################


@pytest.mark.skip(
    reason="add a pyproject extras group in the future "
    "to install HDF5 dependencies such as pytables & h5py"
)
def test_hdf5_file_deserializer_success(tmp_path):
    df_sample = pd.DataFrame({"text": ["hello", "world"]})
    file_path = tmp_path / "sample.h5"
    df_sample.to_hdf(str(file_path), key="data", mode="w", format="fixed")

    deserializer = HDF5FileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=1,
    )

    # Assert
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 2
    assert dataset["text"] == ["hello", "world"]

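# Note: a possible alternative to the hard skip above (a sketch, not part of
# this commit) is to gate on the optional dependencies at runtime with
# pytest.importorskip, so the test runs wherever the HDF5 extras are installed:
#
#     pytest.importorskip("tables")  # pandas' HDF5 backend (pytables)
#     pytest.importorskip("h5py")
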
###################
# Tests DB file deserializer
###################


@pytest.mark.skip(reason="issue: #492")
def test_db_file_deserializer_success(monkeypatch, tmp_path):
    import sqlite3

    def create_sqlite_db(path: Path):
        conn = sqlite3.connect(path)
        cur = conn.cursor()
        cur.execute("CREATE TABLE samples (text TEXT)")
        cur.execute("INSERT INTO samples (text) VALUES ('hello')")
        cur.execute("INSERT INTO samples (text) VALUES ('world')")
        conn.commit()
        conn.close()

    # Arrange: create a valid .db file
    db_path = tmp_path / "sample.db"
    create_sqlite_db(db_path)

    # Arrange: mock Dataset.from_sql to return one dataset
    mocked_ds = Dataset.from_dict({"text": ["hello", "world"]})

    def mock_from_sql(sql, con, **kwargs):
        assert sql == "SELECT * FROM samples"
        assert con == str(db_path)
        return mocked_ds

    monkeypatch.setattr("datasets.Dataset.from_sql", mock_from_sql)

    deserializer = DBFileDatasetDeserializer()

    dataset = deserializer(
        data=db_path,
        processor_factory=processor_factory(),
        random_seed=1,
    )

    # Assert: result is of type Dataset
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 2
    assert dataset["text"] == ["hello", "world"]

###################
# Tests Tar file deserializer
###################


def create_simple_tar(tar_path: Path):
    import tarfile

    # Create a tar file in write mode
    with tarfile.open(tar_path, "w") as tar:
        # Content to be added to the tar file
        content = b"hello world\nthis is a tar file\n"

        # Wrap the bytes in an in-memory file object
        data_stream = io.BytesIO(content)

        # TarInfo holds the archive entry's metadata
        info = tarfile.TarInfo(name="sample.txt")
        info.size = len(content)

        # Write the file into the tar archive
        tar.addfile(info, data_stream)


@pytest.mark.sanity
def test_tar_file_deserializer_success(tmp_path):
    file_path = tmp_path / "sample.tar"
    create_simple_tar(file_path)

    deserializer = TarFileDatasetDeserializer()

    dataset = deserializer(
        data=file_path,
        processor_factory=processor_factory(),
        random_seed=43,
    )

    assert isinstance(dataset, DatasetDict)
    assert "train" in dataset
    assert isinstance(dataset["train"], Dataset)
    assert dataset["train"].num_rows == 1
