|
| 1 | +import os |
| 2 | +import json |
| 3 | +import unittest |
| 4 | + |
| 5 | +from data_juicer.utils.asset_utils import load_words_asset |
| 6 | + |
| 7 | +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase |
| 8 | + |
class LoadWordsAssetTest(DataJuicerTestCaseBase):
    """Tests for ``load_words_asset`` against a throwaway asset directory."""

    def setUp(self) -> None:
        # Relative scratch directory that every test writes into; it is
        # removed again in tearDown.
        self.temp_output_path = 'tmp/test_asset_utils/'

    def tearDown(self) -> None:
        # Imported locally so this class does not depend on an edit to the
        # module-level import block.
        import shutil

        # Remove the scratch directory without shelling out to `rm -rf`:
        # this works on platforms without `rm` and never passes the path
        # through a shell. `ignore_errors=True` keeps the cleanup
        # best-effort and also covers the case where no test created the
        # directory.
        shutil.rmtree(self.temp_output_path, ignore_errors=True)

    def test_basic_func(self):
        """Load the named assets and reject an unknown asset name."""
        # download assets from the remote server
        for asset_name in ('stopwords', 'flagged_words'):
            words_dict = load_words_asset(self.temp_output_path, asset_name)
            # assertGreater reports the actual length on failure.
            self.assertGreater(len(words_dict), 0)
            # The asset is expected to be stored as <asset_name>.json
            # inside the scratch directory after loading.
            asset_file = os.path.join(self.temp_output_path,
                                      f'{asset_name}.json')
            self.assertTrue(os.path.exists(asset_file))

        # non-existing asset
        with self.assertRaises(ValueError):
            load_words_asset(self.temp_output_path, 'non_existing_asset')

    def test_load_from_existing_file(self):
        """Load an asset from a JSON file already present in the directory."""
        os.makedirs(self.temp_output_path, exist_ok=True)
        temp_asset = os.path.join(self.temp_output_path, 'temp_asset.json')
        # Explicit encoding so the fixture does not depend on the locale.
        with open(temp_asset, 'w', encoding='utf-8') as fout:
            json.dump({'test_key': ['test_val']}, fout)

        words_list = load_words_asset(self.temp_output_path, 'temp_asset')
        self.assertEqual(len(words_list), 1)
        self.assertEqual(len(words_list['test_key']), 1)

    def test_load_from_serial_files(self):
        """Load an asset split across several files sharing a name prefix."""
        os.makedirs(self.temp_output_path, exist_ok=True)
        # Two files with the same 'temp_asset' prefix and the same key but
        # different values.
        serial_files = (
            ('temp_asset_v1.json', {'test_key': ['test_val_1']}),
            ('temp_asset_v2.json', {'test_key': ['test_val_2']}),
        )
        for file_name, content in serial_files:
            temp_asset = os.path.join(self.temp_output_path, file_name)
            with open(temp_asset, 'w', encoding='utf-8') as fout:
                json.dump(content, fout)

        words_list = load_words_asset(self.temp_output_path, 'temp_asset')
        # One key is expected, holding the values from both files.
        self.assertEqual(len(words_list), 1)
        self.assertEqual(len(words_list['test_key']), 2)
| 54 | + |
| 55 | + |
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
0 commit comments