Skip to content

Commit d319e80

Browse files
update corresponding unitests
1 parent 2fea541 commit d319e80

File tree

2 files changed

+41
-58
lines changed

2 files changed

+41
-58
lines changed

test/utils/test_config.py

+36-51
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2-
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
1+
import os
142

153
import hydra
164
import paddle
175
import pytest
18-
from omegaconf import DictConfig
6+
import yaml
197

20-
paddle.seed(1024)
8+
# 假设你的回调类在这个路径下
9+
from ppsci.utils.callbacks import InitCallback
2110

11+
# 设置 Paddle 的 seed
12+
paddle.seed(1024)
2213

14+
# 测试函数不需要装饰器
2315
@pytest.mark.parametrize(
2416
"epochs,mode,seed",
2517
[
@@ -28,42 +20,35 @@
2820
(10, "eval", -1),
2921
],
3022
)
31-
def test_invalid_epochs(
32-
epochs,
33-
mode,
34-
seed,
35-
):
36-
@hydra.main(version_base=None, config_path="./", config_name="test_config.yaml")
37-
def main(cfg: DictConfig):
38-
pass
39-
40-
# sys.exit will be called when validation error in pydantic, so there we use
41-
# SystemExit instead of other type of errors.
42-
with pytest.raises(SystemExit):
43-
cfg_dict = dict(
44-
{
45-
"TRAIN": {
46-
"epochs": epochs,
47-
},
48-
"mode": mode,
49-
"seed": seed,
50-
"hydra": {
51-
"callbacks": {
52-
"init_callback": {
53-
"_target_": "ppsci.utils.callbacks.InitCallback"
54-
}
55-
}
56-
},
23+
def test_invalid_epochs(tmpdir, epochs, mode, seed):
24+
cfg_dict = {
25+
"hydra": {
26+
"callbacks": {
27+
"init_callback": {"_target_": "ppsci.utils.callbacks.InitCallback"}
5728
}
58-
)
59-
# print(cfg_dict)
60-
import yaml
61-
62-
with open("test_config.yaml", "w") as f:
63-
yaml.dump(dict(cfg_dict), f)
64-
65-
main()
66-
67-
29+
},
30+
"mode": mode,
31+
"seed": seed,
32+
"TRAIN": {
33+
"epochs": epochs,
34+
},
35+
}
36+
# 创建一个临时的配置文件
37+
dir_ = os.path.dirname(__file__)
38+
config_abs_path = os.path.join(dir_, "test_config.yaml")
39+
with open(config_abs_path, "w") as f:
40+
f.write(yaml.dump(cfg_dict))
41+
42+
# 使用 hydra 的 compose API 来创建配置,而不是使用 main
43+
with hydra.initialize(config_path="./", version_base=None):
44+
cfg = hydra.compose(config_name="test_config.yaml")
45+
# 手动触发回调
46+
with pytest.raises(SystemExit) as exec_info:
47+
InitCallback().on_job_start(config=cfg)
48+
assert exec_info.value.code == 2
49+
# 你现在可以根据需要对 cfg 进行断言或进一步处理
50+
51+
52+
# 这部分通常不需要,除非你想直接从脚本运行测试
6853
if __name__ == "__main__":
6954
pytest.main()

test/utils/test_writer.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@
2121

2222
def test_save_csv_file():
2323
keys = ["x1", "y1", "z1"]
24-
alias_dict = (
25-
{
26-
"x": "x1",
27-
"y": "y1",
28-
"z": "z1",
29-
},
30-
)
24+
alias_dict = {
25+
"x": "x1",
26+
"y": "y1",
27+
"z": "z1",
28+
}
3129
data_dict = {
3230
keys[0]: np.random.randint(0, 255, (10, 1)),
3331
keys[1]: np.random.rand(10, 1),

0 commit comments

Comments
 (0)