Skip to content

Commit 9798e0d

Browse files
authored
add unit tests for utils module (#616)
* + add unit tests for utils module * fix some bugs * * reset the logger in setup method * * update * * update * * update
1 parent 628e355 commit 9798e0d

24 files changed

+1259
-11
lines changed

data_juicer/utils/asset_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def load_words_asset(words_dir: str, words_type: str):
4848
logger.info(f'Specified {words_dir} does not contain '
4949
f'any {words_type} files in json format, now '
5050
'download the one cached by data_juicer team')
51+
if words_type not in ASSET_LINKS:
52+
raise ValueError(f'{words_type} is not in remote server.')
5153
response = requests.get(ASSET_LINKS[words_type])
5254
words_dict = response.json()
5355
# cache the asset file locally

data_juicer/utils/compress.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class FileLock(HF_FileLock):
2323
def _release(self):
2424
super()._release()
2525
try:
26-
# logger.debug(f'Remove {self._lock_file}')
27-
os.remove(self._lock_file)
26+
# logger.debug(f'Remove {self.lock_file}')
27+
os.remove(self.lock_file)
2828
# The file is already deleted and that's what we want.
2929
except OSError:
3030
pass
@@ -497,4 +497,4 @@ def decompress(ds, fingerprints=None, num_proc=1):
497497

498498

499499
def cleanup_compressed_cache_files(ds):
500-
CacheCompressManager().cleanup_cache_files(ds)
500+
CacheCompressManager(cache_utils.CACHE_COMPRESS).cleanup_cache_files(ds)

data_juicer/utils/mm_utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def iou(box1, box2):
160160
ix_max = min(x1_max, x2_max)
161161
iy_min = max(y1_min, y2_min)
162162
iy_max = min(y1_max, y2_max)
163-
intersection = max(0, (ix_max - ix_min) * (iy_max - iy_min))
163+
intersection = max(0, max(0, ix_max - ix_min) * max(0, iy_max - iy_min))
164164
union = area1 + area2 - intersection
165-
return 1.0 * intersection / union
165+
return 1.0 * intersection / union if union != 0 else 0.0
166166

167167

168168
def calculate_resized_dimensions(
@@ -207,7 +207,7 @@ def calculate_resized_dimensions(
207207

208208
# Determine final dimensions based on original orientation
209209
resized_dimensions = ((new_short_edge,
210-
new_long_edge) if width <= height else
210+
new_long_edge) if width >= height else
211211
(new_long_edge, new_short_edge))
212212

213213
# Ensure final dimensions are divisible by the specified value

data_juicer/utils/registry.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
# https://github.com/modelscope/modelscope/blob/master/modelscope/utils/registry.py
1818
# --------------------------------------------------------
1919

20-
from loguru import logger
21-
2220

2321
class Registry(object):
2422
"""This class is used to register some modules to registry by a repo
@@ -53,8 +51,7 @@ def modules(self):
5351

5452
def list(self):
5553
"""Logging the list of module in current registry."""
56-
for m in self._modules.keys():
57-
logger.info(f'{self._name}\t{m}')
54+
return list(self._modules.keys())
5855

5956
def get(self, module_key):
6057
"""

tests/run.py

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def main():
9797
runner = unittest.TextTestRunner()
9898
test_suite = gather_test_cases(os.path.abspath(args.test_dir),
9999
args.pattern, args.tag, args.mode)
100+
logger.info(f'There are {len(test_suite._tests)} test cases to run.')
100101
res = runner.run(test_suite)
101102

102103
cov.stop()

tests/utils/test_asset_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
import json
3+
import unittest
4+
5+
from data_juicer.utils.asset_utils import load_words_asset
6+
7+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
8+
9+
class LoadWordsAssetTest(DataJuicerTestCaseBase):
10+
11+
def setUp(self) -> None:
12+
self.temp_output_path = 'tmp/test_asset_utils/'
13+
14+
def tearDown(self):
15+
if os.path.exists(self.temp_output_path):
16+
os.system(f'rm -rf {self.temp_output_path}')
17+
18+
def test_basic_func(self):
19+
# download assets from the remote server
20+
words_dict = load_words_asset(self.temp_output_path, 'stopwords')
21+
self.assertTrue(len(words_dict) > 0)
22+
self.assertTrue(os.path.exists(os.path.join(self.temp_output_path, 'stopwords.json')))
23+
24+
words_dict = load_words_asset(self.temp_output_path, 'flagged_words')
25+
self.assertTrue(len(words_dict) > 0)
26+
self.assertTrue(os.path.exists(os.path.join(self.temp_output_path, 'flagged_words.json')))
27+
28+
# non-existing asset
29+
with self.assertRaises(ValueError):
30+
load_words_asset(self.temp_output_path, 'non_existing_asset')
31+
32+
def test_load_from_existing_file(self):
33+
os.makedirs(self.temp_output_path, exist_ok=True)
34+
temp_asset = os.path.join(self.temp_output_path, 'temp_asset.json')
35+
with open(temp_asset, 'w') as fout:
36+
json.dump({'test_key': ['test_val']}, fout)
37+
38+
words_list = load_words_asset(self.temp_output_path, 'temp_asset')
39+
self.assertEqual(len(words_list), 1)
40+
self.assertEqual(len(words_list['test_key']), 1)
41+
42+
def test_load_from_serial_files(self):
43+
os.makedirs(self.temp_output_path, exist_ok=True)
44+
temp_asset = os.path.join(self.temp_output_path, 'temp_asset_v1.json')
45+
with open(temp_asset, 'w') as fout:
46+
json.dump({'test_key': ['test_val_1']}, fout)
47+
temp_asset = os.path.join(self.temp_output_path, 'temp_asset_v2.json')
48+
with open(temp_asset, 'w') as fout:
49+
json.dump({'test_key': ['test_val_2']}, fout)
50+
51+
words_list = load_words_asset(self.temp_output_path, 'temp_asset')
52+
self.assertEqual(len(words_list), 1)
53+
self.assertEqual(len(words_list['test_key']), 2)
54+
55+
56+
if __name__ == '__main__':
57+
unittest.main()
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import unittest
2+
3+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
4+
5+
class AutoInstallMappingTest(DataJuicerTestCaseBase):
6+
7+
def test_placeholder(self):
8+
pass
9+
10+
11+
if __name__ == '__main__':
12+
unittest.main()
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import unittest
2+
3+
from data_juicer.utils.auto_install_utils import _is_module_installed, _is_package_installed
4+
5+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
6+
7+
class IsXXXInstalledFuncsTest(DataJuicerTestCaseBase):
8+
9+
def test_is_module_installed(self):
10+
self.assertTrue(_is_module_installed('datasets'))
11+
self.assertTrue(_is_module_installed('simhash'))
12+
13+
self.assertFalse(_is_module_installed('non_existent_module'))
14+
15+
def test_is_package_installed(self):
16+
self.assertTrue(_is_package_installed('datasets'))
17+
self.assertTrue(_is_package_installed('ram@git+https://github.com/xinyu1205/recognize-anything.git'))
18+
self.assertTrue(_is_package_installed('scenedetect[opencv]'))
19+
20+
self.assertFalse(_is_package_installed('non_existent_package'))
21+
22+
23+
if __name__ == '__main__':
24+
unittest.main()
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import unittest
2+
3+
from data_juicer.utils.availability_utils import _is_package_available
4+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
5+
6+
class AvailabilityUtilsTest(DataJuicerTestCaseBase):
7+
8+
def test_is_package_available(self):
9+
exist = _is_package_available('fsspec')
10+
self.assertTrue(exist)
11+
exist, version = _is_package_available('fsspec', return_version=True)
12+
self.assertTrue(exist)
13+
self.assertEqual(version, '2023.5.0')
14+
15+
exist = _is_package_available('non_existing_package')
16+
self.assertFalse(exist)
17+
exist, version = _is_package_available('non_existing_package', return_version=True)
18+
self.assertFalse(exist)
19+
self.assertEqual(version, 'N/A')
20+
21+
22+
if __name__ == '__main__':
23+
unittest.main()

tests/utils/test_cache_utils.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
3+
import datasets
4+
5+
from data_juicer.utils.cache_utils import DatasetCacheControl, dataset_cache_control
6+
7+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
8+
9+
class DatasetCacheControlTest(DataJuicerTestCaseBase):
10+
11+
def test_basic_func(self):
12+
self.assertTrue(datasets.is_caching_enabled())
13+
with DatasetCacheControl(on=False):
14+
self.assertFalse(datasets.is_caching_enabled())
15+
self.assertTrue(datasets.is_caching_enabled())
16+
17+
with DatasetCacheControl(on=False):
18+
self.assertFalse(datasets.is_caching_enabled())
19+
with DatasetCacheControl(on=True):
20+
self.assertTrue(datasets.is_caching_enabled())
21+
self.assertFalse(datasets.is_caching_enabled())
22+
self.assertTrue(datasets.is_caching_enabled())
23+
24+
def test_decorator(self):
25+
26+
@dataset_cache_control(on=False)
27+
def check():
28+
return datasets.is_caching_enabled()
29+
30+
self.assertTrue(datasets.is_caching_enabled())
31+
self.assertFalse(check())
32+
self.assertTrue(datasets.is_caching_enabled())
33+
34+
35+
if __name__ == '__main__':
36+
unittest.main()

tests/utils/test_ckpt_utils.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import unittest
3+
import json
4+
5+
from data_juicer.core.data import NestedDataset
6+
from data_juicer.utils.ckpt_utils import CheckpointManager
7+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
8+
9+
class CkptUtilsTest(DataJuicerTestCaseBase):
10+
11+
def setUp(self) -> None:
12+
self.temp_output_path = 'tmp/test_ckpt_utils/'
13+
14+
def tearDown(self):
15+
if os.path.exists(self.temp_output_path):
16+
os.system(f'rm -rf {self.temp_output_path}')
17+
18+
def test_basic_func(self):
19+
ckpt_path = os.path.join(self.temp_output_path, 'ckpt_1')
20+
manager = CheckpointManager(ckpt_path, original_process_list=[
21+
{'test_op_1': {'test_key': 'test_value_1'}},
22+
{'test_op_2': {'test_key': 'test_value_2'}},
23+
])
24+
self.assertEqual(manager.get_left_process_list(), [
25+
{'test_op_1': {'test_key': 'test_value_1'}},
26+
{'test_op_2': {'test_key': 'test_value_2'}},
27+
])
28+
self.assertFalse(manager.ckpt_available)
29+
30+
self.assertFalse(manager.check_ckpt())
31+
os.makedirs(ckpt_path, exist_ok=True)
32+
os.makedirs(os.path.join(ckpt_path, 'latest'), exist_ok=True)
33+
with open(os.path.join(ckpt_path, 'ckpt_op.json'), 'w') as fout:
34+
json.dump([
35+
{'test_op_1': {'test_key': 'test_value_1'}},
36+
], fout)
37+
self.assertTrue(manager.check_ops_to_skip())
38+
39+
manager = CheckpointManager(ckpt_path, original_process_list=[
40+
{'test_op_1': {'test_key': 'test_value_1'}},
41+
{'test_op_2': {'test_key': 'test_value_2'}},
42+
])
43+
with open(os.path.join(ckpt_path, 'ckpt_op.json'), 'w') as fout:
44+
json.dump([
45+
{'test_op_1': {'test_key': 'test_value_1'}},
46+
{'test_op_2': {'test_key': 'test_value_2'}},
47+
], fout)
48+
self.assertFalse(manager.check_ops_to_skip())
49+
50+
def test_different_ops(self):
51+
ckpt_path = os.path.join(self.temp_output_path, 'ckpt_2')
52+
os.makedirs(ckpt_path, exist_ok=True)
53+
os.makedirs(os.path.join(ckpt_path, 'latest'), exist_ok=True)
54+
with open(os.path.join(ckpt_path, 'ckpt_op.json'), 'w') as fout:
55+
json.dump([
56+
{'test_op_2': {'test_key': 'test_value_2'}},
57+
], fout)
58+
manager = CheckpointManager(ckpt_path, original_process_list=[
59+
{'test_op_1': {'test_key': 'test_value_1'}},
60+
{'test_op_2': {'test_key': 'test_value_2'}},
61+
])
62+
self.assertFalse(manager.ckpt_available)
63+
64+
def test_save_and_load_ckpt(self):
65+
ckpt_path = os.path.join(self.temp_output_path, 'ckpt_3')
66+
test_data = {
67+
'text': ['text1', 'text2', 'text3'],
68+
}
69+
dataset = NestedDataset.from_dict(test_data)
70+
manager = CheckpointManager(ckpt_path, original_process_list=[])
71+
self.assertFalse(os.path.exists(os.path.join(manager.ckpt_ds_dir, 'dataset_info.json')))
72+
manager.record({'test_op_1': {'test_key': 'test_value_1'}})
73+
manager.save_ckpt(dataset)
74+
self.assertTrue(os.path.exists(os.path.join(manager.ckpt_ds_dir, 'dataset_info.json')))
75+
self.assertTrue(os.path.exists(manager.ckpt_op_record))
76+
loaded_ckpt = manager.load_ckpt()
77+
self.assertDatasetEqual(dataset, loaded_ckpt)
78+
79+
80+
if __name__ == '__main__':
81+
unittest.main()

tests/utils/test_common_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import unittest
2+
import sys
3+
4+
from data_juicer.utils.common_utils import (
5+
stats_to_number, dict_to_hash, nested_access, is_string_list,
6+
avg_split_string_list_under_limit, is_float
7+
)
8+
9+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
10+
11+
class CommonUtilsTest(DataJuicerTestCaseBase):
12+
13+
def test_stats_to_number(self):
14+
self.assertEqual(stats_to_number('1.0'), 1.0)
15+
self.assertEqual(stats_to_number([1.0, 2.0, 3.0]), 2.0)
16+
17+
self.assertEqual(stats_to_number([]), -sys.maxsize)
18+
self.assertEqual(stats_to_number(None), -sys.maxsize)
19+
self.assertEqual(stats_to_number([], reverse=False), sys.maxsize)
20+
self.assertEqual(stats_to_number(None, reverse=False), sys.maxsize)
21+
22+
def test_dict_to_hash(self):
23+
self.assertEqual(len(dict_to_hash({'a': 1, 'b': 2})), 64)
24+
self.assertEqual(len(dict_to_hash({'a': 1, 'b': 2}, hash_length=32)), 32)
25+
26+
def test_nested_access(self):
27+
self.assertEqual(nested_access({'a': {'b': 1}}, 'a.b'), 1)
28+
self.assertEqual(nested_access({'a': [{'b': 1}]}, 'a.0.b', digit_allowed=True), 1)
29+
self.assertEqual(nested_access({'a': [{'b': 1}]}, 'a.0.b', digit_allowed=False), None)
30+
31+
def test_is_string_list(self):
32+
self.assertTrue(is_string_list(['a', 'b', 'c']))
33+
self.assertFalse(is_string_list([1, 2, 3]))
34+
self.assertFalse(is_string_list(['a', 2, 'c']))
35+
36+
def test_is_float(self):
37+
self.assertTrue(is_float('1.0'))
38+
self.assertTrue(is_float(1.0))
39+
self.assertTrue(is_float('1e-4'))
40+
self.assertFalse(is_float('a'))
41+
42+
def test_avg_split_string_list_under_limit(self):
43+
test_data = [
44+
(['a', 'b', 'c'], [1, 2, 3], None, [['a', 'b', 'c']]),
45+
(['a', 'b', 'c'], [1, 2, 3], 3, [['a', 'b'], ['c']]),
46+
(['a', 'b', 'c'], [1, 2, 3], 2, [['a'], ['b'], ['c']]),
47+
(['a', 'b', 'c', 'd', 'e'], [1, 2, 3, 1, 1], 3, [['a', 'b'], ['c'], ['d', 'e']]),
48+
(['a', 'b', 'c'], [1, 2], 3, [['a', 'b', 'c']]),
49+
(['a', 'b', 'c'], [1, 2, 3], 100, [['a', 'b', 'c']]),
50+
]
51+
52+
for str_list, token_nums, max_token_num, expected_result in test_data:
53+
self.assertEqual(avg_split_string_list_under_limit(str_list, token_nums, max_token_num), expected_result)
54+
55+
56+
if __name__ == '__main__':
57+
unittest.main()

0 commit comments

Comments
 (0)