|
1 |
| -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
2 |
| - |
3 |
| -# Licensed under the Apache License, Version 2.0 (the "License"); |
4 |
| -# you may not use this file except in compliance with the License. |
5 |
| -# You may obtain a copy of the License at |
6 |
| - |
7 |
| -# http://www.apache.org/licenses/LICENSE-2.0 |
8 |
| - |
9 |
| -# Unless required by applicable law or agreed to in writing, software |
10 |
| -# distributed under the License is distributed on an "AS IS" BASIS, |
11 |
| -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 |
| -# See the License for the specific language governing permissions and |
13 |
| -# limitations under the License. |
| 1 | +import os |
14 | 2 |
|
15 | 3 | import hydra
|
16 | 4 | import paddle
|
17 | 5 | import pytest
|
18 |
| -from omegaconf import DictConfig |
| 6 | +import yaml |
19 | 7 |
|
20 |
| -paddle.seed(1024) |
| 8 | +# 假设你的回调类在这个路径下 |
| 9 | +from ppsci.utils.callbacks import InitCallback |
21 | 10 |
|
| 11 | +# 设置 Paddle 的 seed |
| 12 | +paddle.seed(1024) |
22 | 13 |
|
| 14 | +# 测试函数不需要装饰器 |
23 | 15 | @pytest.mark.parametrize(
|
24 | 16 | "epochs,mode,seed",
|
25 | 17 | [
|
|
28 | 20 | (10, "eval", -1),
|
29 | 21 | ],
|
30 | 22 | )
|
31 |
| -def test_invalid_epochs( |
32 |
| - epochs, |
33 |
| - mode, |
34 |
| - seed, |
35 |
| -): |
36 |
| - @hydra.main(version_base=None, config_path="./", config_name="test_config.yaml") |
37 |
| - def main(cfg: DictConfig): |
38 |
| - pass |
39 |
| - |
40 |
| - # sys.exit will be called when validation error in pydantic, so there we use |
41 |
| - # SystemExit instead of other type of errors. |
42 |
| - with pytest.raises(SystemExit): |
43 |
| - cfg_dict = dict( |
44 |
| - { |
45 |
| - "TRAIN": { |
46 |
| - "epochs": epochs, |
47 |
| - }, |
48 |
| - "mode": mode, |
49 |
| - "seed": seed, |
50 |
| - "hydra": { |
51 |
| - "callbacks": { |
52 |
| - "init_callback": { |
53 |
| - "_target_": "ppsci.utils.callbacks.InitCallback" |
54 |
| - } |
55 |
| - } |
56 |
| - }, |
| 23 | +def test_invalid_epochs(tmpdir, epochs, mode, seed): |
| 24 | + cfg_dict = { |
| 25 | + "hydra": { |
| 26 | + "callbacks": { |
| 27 | + "init_callback": {"_target_": "ppsci.utils.callbacks.InitCallback"} |
57 | 28 | }
|
58 |
| - ) |
59 |
| - # print(cfg_dict) |
60 |
| - import yaml |
61 |
| - |
62 |
| - with open("test_config.yaml", "w") as f: |
63 |
| - yaml.dump(dict(cfg_dict), f) |
64 |
| - |
65 |
| - main() |
66 |
| - |
67 |
| - |
| 29 | + }, |
| 30 | + "mode": mode, |
| 31 | + "seed": seed, |
| 32 | + "TRAIN": { |
| 33 | + "epochs": epochs, |
| 34 | + }, |
| 35 | + } |
| 36 | + # 创建一个临时的配置文件 |
| 37 | + dir_ = os.path.dirname(__file__) |
| 38 | + config_abs_path = os.path.join(dir_, "test_config.yaml") |
| 39 | + with open(config_abs_path, "w") as f: |
| 40 | + f.write(yaml.dump(cfg_dict)) |
| 41 | + |
| 42 | + # 使用 hydra 的 compose API 来创建配置,而不是使用 main |
| 43 | + with hydra.initialize(config_path="./", version_base=None): |
| 44 | + cfg = hydra.compose(config_name="test_config.yaml") |
| 45 | + # 手动触发回调 |
| 46 | + with pytest.raises(SystemExit) as exec_info: |
| 47 | + InitCallback().on_job_start(config=cfg) |
| 48 | + assert exec_info.value.code == 2 |
| 49 | + # 你现在可以根据需要对 cfg 进行断言或进一步处理 |
| 50 | + |
| 51 | + |
| 52 | +# 这部分通常不需要,除非你想直接从脚本运行测试 |
68 | 53 | if __name__ == "__main__":
|
69 | 54 | pytest.main()
|
0 commit comments