diff --git a/.gitignore b/.gitignore index b6d910918d..4acd0057e1 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,21 @@ FETCH_HEAD # vtk *.vtk *.vtu
+
+# auto generated version file by setuptools_scm
+ppsci/_version.py
+# Model checkpoints
+*.pdparams
+*.pdopt
+*.pdstates
+
+# Output directories
+outputs_HumanoidControl/
+
+# Screenshots and images
+test_screenshot/
+
+# Cache files
+__pycache__/
+*.py[cod]
+*.class
diff --git a/README.md b/README.md index ff9dc4f9f2..bb6401b5e0 100644 --- a/README.md +++ b/README.md @@ -54,15 +54,15 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 # write your code here... ```
-更多安装方式请参考 [**安装与使用**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/install_setup/)
+更多安装方式请参考 [**安装与使用**](https://paddlescience-docs.readthedocs.io/zh/release-1.0/zh/install_setup/)
## 快速开始
-请参考 [**快速开始**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/quickstart/)
+请参考 [**快速开始**](https://paddlescience-docs.readthedocs.io/zh/release-1.0/zh/quickstart/)
## 经典案例
-请参考 [**经典案例**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/viv/)
+请参考 [**经典案例**](https://paddlescience-docs.readthedocs.io/zh/release-1.0/zh/examples/viv/)
## 支持
@@ -70,7 +70,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 ## 贡献代码
-PaddleScience 项目欢迎并依赖开发人员和开源社区中的用户,请参阅 [**贡献指南**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/contribute/)。
+PaddleScience 项目欢迎并依赖开发人员和开源社区中的用户,请参阅 [**贡献指南**](https://paddlescience-docs.readthedocs.io/zh/release-1.0/zh/contribute/)。
## 证书
diff --git a/docs/images/overview/panorama.png b/docs/images/overview/panorama.png index 01d98652ea..64bac4cd31 100644 Binary files a/docs/images/overview/panorama.png and b/docs/images/overview/panorama.png differ
diff --git a/docs/index.md b/docs/index.md index 1726766563..99a5de427f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -21,7 +21,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 ## 贡献代码
-PaddleScience 项目欢迎并依赖开发人员和开源社区中的用户,请参阅 [**贡献指南**](https://paddlescience-docs.readthedocs.io/zh/latest/zh/contribute/)。
+PaddleScience 项目欢迎并依赖开发人员和开源社区中的用户,请参阅 [**贡献指南**](https://paddlescience-docs.readthedocs.io/zh/release-1.0/zh/contribute/)。
## 证书
diff --git a/docs/zh/examples/mujoco_control.md b/docs/zh/examples/mujoco_control.md new file mode 100644 index 0000000000..875051923a --- /dev/null +++ b/docs/zh/examples/mujoco_control.md @@ -0,0 +1,218 @@ +本项目使用PaddleScience和DeepMind Control Suite (dm_control) 实现了一个人形机器人(Humanoid)的运动控制系统。该系统通过深度学习方法,学习控制人形机器人进行稳定的运动。PINN(Physics-informed Neural Network)方法利用控制方程加速神经网络收敛,甚至可在无训练数据的情况下实现无监督学习;本案例在此基础上尝试实现 Humanoid 的控制仿真。
+
1. 开发指南 - PaddleScience Docs (paddlescience-docs.readthedocs.io)
2. google-deepmind/dm_control: Google DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo. (github.com)
pip install dm_control
+
安装 PaddlePaddle(CUDA 11.8)
python3 -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
+
安装 PaddleScience
git clone -b develop https://github.com/PaddlePaddle/PaddleScience.git
### 若 github clone 速度比较慢,可以使用 gitee clone
### git clone -b develop https://gitee.com/paddlepaddle/PaddleScience.git
cd PaddleScience
### install paddlescience with editable mode
python -m pip install -e .
-i https://pypi.tuna.tsinghua.edu.cn/simple + +### MuJoCo Humanoid Control with PaddleScience + + +### 主要特点 +- 使用PaddleScience框架进行深度学习模型训练 +- 基于dm_control的MuJoCo物理引擎进行机器人仿真 +- 实现了自监督学习方案 +- 提供了完整的训练和评估流程 +- 包含详细的性能分析和可视化工具 + +## 项目结构 +PaddleScience/examples/ +``` +mujoco_control/ +├── conf/ +│ └── humanoid_control.yaml # 配置文件 +├── humanoid_complete.py # 主程序文件 +└── outputs_HumanoidControl/ # 输出目录 + └── YYYY-MM-DD/ # 按日期组织的输出 + ├── checkpoints/ # 模型检查点 + ├── evaluation/ # 评估结果 + └── logs/ # 训练日志 +``` +``` +── conf +│ └── humanoid_control.yaml +├── humanoid_complete.py +└── outputs_HumanoidControl + ├── 13-17-41 + │ └── mode=train + │ ├── checkpoints + │ │ ├── epoch_10.pdopt + │ │ ├── epoch_10.pdparams + │ │ ├── epoch_10.pdstates + │ │ ├── epoch_100.pdopt + │ │ ├── epoch_100.pdparams + │ │ ├── epoch_100.pdstates + │ │ ├── epoch_20.pdopt + │ │ ├── epoch_20.pdparams + │ │ ├── epoch_20.pdstates + │ │ ├── epoch_30.pdopt + │ │ ├── epoch_30.pdparams + │ │ ├── epoch_30.pdstates + │ │ ├── epoch_40.pdopt + │ │ ├── epoch_40.pdparams + │ │ ├── epoch_40.pdstates + │ │ ├── epoch_50.pdopt + │ │ ├── epoch_50.pdparams + │ │ ├── epoch_50.pdstates + │ │ ├── epoch_60.pdopt + │ │ ├── epoch_60.pdparams + │ │ ├── epoch_60.pdstates + │ │ ├── epoch_70.pdopt + │ │ ├── epoch_70.pdparams + │ │ ├── epoch_70.pdstates + │ │ ├── epoch_80.pdopt + │ │ ├── epoch_80.pdparams + │ │ ├── epoch_80.pdstates + │ │ ├── epoch_90.pdopt + │ │ ├── epoch_90.pdparams + │ │ ├── epoch_90.pdstates + │ │ ├── latest.pdopt + │ │ ├── latest.pdparams + │ │ └── latest.pdstates + │ └── train.log +``` +## 核心组件 + +### 1. 数据集类 (HumanoidDataset) +```python +class HumanoidDataset: + """处理训练数据的收集和预处理""" + def __init__(self, num_episodes=1000, episode_length=1000, ratio_split=0.8) + def collect_episode_data(self) # 收集单个回合数据 + def _flatten_observation(self) # 处理观察数据 + def generate_dataset(self) # 生成训练集和验证集 +``` + +### 2. 控制器模型 (HumanoidController) +```python +class HumanoidController(paddle.nn.Layer): + """神经网络控制器""" + def __init__(self, state_size, action_size, hidden_size=256) + def forward(self, x) # 前向传播,预测动作 +``` + +### 3. 评估器类 (HumanoidEvaluator) +```python +class HumanoidEvaluator: + """模型评估和可视化""" + def __init__(self, model_path, num_episodes=5, episode_length=1000) + def evaluate_episode(self) # 评估单个回合 + def run_evaluation(self) # 运行完整评估 +``` + +## 配置说明 + +主要配置参数(在humanoid_control.yaml中): + +```yaml +DATA: + num_episodes: 100 # 训练回合数 + episode_length: 500 # 每回合步数 + +MODEL: + hidden_size: 256 # 隐藏层大小 + +TRAIN: + epochs: 100 # 训练轮数 + batch_size: 32 # 批次大小 + learning_rate: 0.001 # 学习率 + +EVAL: + num_episodes: 5 # 评估回合数 + episode_length: 1000 # 评估步数长度 +``` + +## 训练流程 + +### 训练方法 +1. 数据收集: + - 使用随机策略收集初始训练数据 + - 将数据分割为训练集和验证集 + +2. 模型训练: + - 使用PaddleScience的训练框架 + - 实现了自定义损失函数 + - 包含动作预测和奖励最大化两个目标 + +3. 训练命令: +```bash +python humanoid_complete.py mode=train +``` + +### 评估方法 +1. 模型评估: + - 在真实环境中运行训练好的模型 + - 收集性能指标 + - 生成评估视频(如果可用) + +2. 评估命令: +```bash +python humanoid_complete.py mode=eval +EVAL.pretrained_model_path="path/to/checkpoint" +``` + +## 性能分析 + +评估过程会生成以下分析结果: +- 总体奖励统计 +- 动作模式分析 +- 性能可视化图表 +- 评估视频(如果启用) + +## 输出说明 + +### 训练输出 +- 模型检查点 +- 训练日志 +- 学习曲线 + +### 评估输出 +- 统计数据文件 (stats.txt) +- 性能分析图表 +- 评估视频文件(如果启用) + +## 使用示例 + +1. 训练新模型: +python humanoid_complete.py mode=train + +2. 评估已训练模型: +python humanoid_complete.py mode=eval + +## 注意事项 + +1. 环境要求: + - PaddlePaddle >= 3.0.0 + - dm_control + - MuJoCo物理引擎 + - Python >= 3.7 (测试环境为3.10.15) + +2. 性能优化建议: + - 适当调整batch_size和learning_rate + - 根据需要修改网络结构 + - 可以通过修改配置文件调整训练参数 + +3. 
已知问题: + - WSL2环境下可能存在可视化问题 + - 需要使用适当的渲染后端 + +## 未来改进 + +1. 功能扩展: + - 添加更多控制策略 + - 实现多种任务场景 + - 增强可视化功能 + +2. 性能优化: + - 改进训练效率 + - 优化模型结构 + - 增加并行训练支持 diff --git a/docs/zh/install_setup.md b/docs/zh/install_setup.md index d35904f471..40e87317ee 100644 --- a/docs/zh/install_setup.md +++ b/docs/zh/install_setup.md @@ -26,25 +26,20 @@ pip install -r requirements.txt ``` - ???+ Info "安装注意事项" - - 如需使用外部导入STL文件来构建几何,以及使用加密采样等功能,还需额外安装三个依赖库: - [pymesh](https://pymesh.readthedocs.io/en/latest/installation.html#download-the-source)(推荐编译安装), - [open3d](https://github.com/isl-org/Open3D/tree/master#python-quick-start)(推荐pip安装), - [pysdf](https://github.com/sxyu/sdf)(推荐pip安装) - #### 1.2.2 pip 安装 -coming soon - -
+???+ Info "安装注意事项" + + 如需使用外部导入STL文件来构建几何,以及使用加密采样等功能,还需额外安装三个依赖库: + [pymesh](https://pymesh.readthedocs.io/en/latest/installation.html#download-the-source)(推荐编译安装), + [open3d](https://github.com/isl-org/Open3D/tree/master#python-quick-start)(推荐pip安装), + [pysdf](https://github.com/sxyu/sdf)(推荐pip安装) ## 2. 验证安装 diff --git a/examples/mujoco_control/conf/humanoid_control.yaml b/examples/mujoco_control/conf/humanoid_control.yaml new file mode 100644 index 0000000000..7e71a8dfd6 --- /dev/null +++ b/examples/mujoco_control/conf/humanoid_control.yaml @@ -0,0 +1,41 @@ +defaults: + - _self_ + +hydra: + run: + dir: outputs_HumanoidControl/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} + chdir: false + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: train +seed: 42 +output_dir: ${hydra:run.dir} +log_freq: 20 + +DATA: + num_episodes: 100 + episode_length: 500 + +MODEL: + hidden_size: 256 + +TRAIN: + epochs: 100 + iters_per_epoch: 10 + save_freq: 10 + learning_rate: 0.001 + batch_size: 32 + pretrained_model_path: null + checkpoint_path: null + eval_with_no_grad: true + +EVAL: + pretrained_model_path: outputs_HumanoidControl/2024-12-15/22-02-39/mode=train/checkpoints/latest.pdparams + eval_with_no_grad: true + num_episodes: 5 + episode_length: 1000 + interactive: false diff --git a/examples/mujoco_control/humanoid_complete.py b/examples/mujoco_control/humanoid_complete.py new file mode 100644 index 0000000000..415c6e1cbb --- /dev/null +++ b/examples/mujoco_control/humanoid_complete.py @@ -0,0 +1,441 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +from os import path as osp +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import paddle +from dm_control import suite +from omegaconf import DictConfig + +import ppsci +from ppsci.utils import logger + + +class HumanoidDataset: + def __init__(self, num_episodes=1000, episode_length=1000, ratio_split=0.8): + self.env = suite.load(domain_name="humanoid", task_name="run") + self.num_episodes = num_episodes + self.episode_length = episode_length + self.ratio_split = ratio_split + + def collect_episode_data(self): + """Collect single episode data""" + states, actions, rewards = [], [], [] + time_step = self.env.reset() + + # Get action specification for random sampling + action_spec = self.env.action_spec() + + for _ in range(self.episode_length): + action = np.random.uniform( + action_spec.minimum, action_spec.maximum, size=action_spec.shape + ) + + states.append(self._flatten_observation(time_step.observation)) + actions.append(action) + + time_step = self.env.step(action) + rewards.append(time_step.reward if time_step.reward is not None else 0.0) + + if time_step.last(): + break + + return np.array(states), np.array(actions), np.array(rewards) + + def _flatten_observation(self, observation): + """Flatten observation dict to array""" + return np.concatenate([v.flatten() for v in observation.values()]) + + def generate_dataset(self): + all_states, all_actions, all_rewards = [], [], [] + + print("Collecting training data...") + for i in range(self.num_episodes): + if i % 10 == 0: + print(f"Episode {i}/{self.num_episodes}") + states, actions, rewards = self.collect_episode_data() + all_states.append(states) + all_actions.append(actions) + all_rewards.append(rewards) + + states = np.array(all_states) + actions = np.array(all_actions) + rewards = np.array(all_rewards) + + split_idx = int(self.num_episodes * self.ratio_split) + + train_data = { + "input": 
{"state": states[:split_idx].reshape(-1, states.shape[-1])}, + "label": { + "action": actions[:split_idx].reshape(-1, actions.shape[-1]), + "reward": rewards[:split_idx].reshape(-1, 1), + }, + } + + val_data = { + "input": {"state": states[split_idx:].reshape(-1, states.shape[-1])}, + "label": { + "action": actions[split_idx:].reshape(-1, actions.shape[-1]), + "reward": rewards[split_idx:].reshape(-1, 1), + }, + } + + return train_data, val_data + + +class HumanoidController(paddle.nn.Layer): + def __init__(self, state_size, action_size, hidden_size=256): + super().__init__() + self.net = paddle.nn.Sequential( + paddle.nn.Linear(state_size, hidden_size), + paddle.nn.ReLU(), + paddle.nn.Linear(hidden_size, hidden_size), + paddle.nn.ReLU(), + paddle.nn.Linear(hidden_size, action_size), + paddle.nn.Tanh(), + ) + + def forward(self, x): + state = paddle.to_tensor(x["state"], dtype="float32") + return {"action": self.net(state)} + + +def train_loss_func(output_dict, label_dict, weight_dict=None): + """Calculate training loss with properly named components""" + # Predict next state and maximize reward + action_loss = paddle.mean( + paddle.square(output_dict["action"] - label_dict["action"]) + ) + reward_loss = -paddle.mean( + label_dict["reward"] + ) # Negative since we want to maximize reward + total_loss = action_loss + 0.1 * reward_loss + return {"total_loss": total_loss} + + +def metric_eval(output_dict, label_dict=None, weight_dict=None): + """Simple metric function that returns a single scalar value""" + # Use the same calculation as training loss + action_loss = float( + paddle.mean(paddle.square(output_dict["action"] - label_dict["action"])) + ) + reward_loss = float(-paddle.mean(label_dict["reward"])) + total_loss = action_loss + 0.1 * reward_loss + + # Return a single scalar metric + return {"val_loss": total_loss} + + +def train(cfg: DictConfig): + # Set random seed + ppsci.utils.misc.set_random_seed(cfg.seed) + logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") + + # Generate dataset + dataset = HumanoidDataset( + num_episodes=cfg.DATA.num_episodes, episode_length=cfg.DATA.episode_length + ) + train_data, val_data = dataset.generate_dataset() + + # Initialize model + state_size = train_data["input"]["state"].shape[-1] + action_size = train_data["label"]["action"].shape[-1] + model = HumanoidController(state_size, action_size, cfg.MODEL.hidden_size) + + # Convert data to float32 + train_data["input"]["state"] = train_data["input"]["state"].astype("float32") + train_data["label"]["action"] = train_data["label"]["action"].astype("float32") + train_data["label"]["reward"] = train_data["label"]["reward"].astype("float32") + + val_data["input"]["state"] = val_data["input"]["state"].astype("float32") + val_data["label"]["action"] = val_data["label"]["action"].astype("float32") + val_data["label"]["reward"] = val_data["label"]["reward"].astype("float32") + + # Create training constraint + sup_constraint = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "NamedArrayDataset", + "input": train_data["input"], + "label": train_data["label"], + }, + "batch_size": cfg.TRAIN.batch_size, + "sampler": { + "name": "BatchSampler", + "drop_last": False, + "shuffle": True, + }, + }, + ppsci.loss.FunctionalLoss(train_loss_func), + {"action": lambda out: out["action"]}, + name="sup_train", + ) + + # Create validator + # In your train function, update the validator creation: + sup_validator = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": 
"NamedArrayDataset", + "input": val_data["input"], + "label": val_data["label"], + }, + "batch_size": cfg.TRAIN.batch_size, + "sampler": { + "name": "BatchSampler", + "drop_last": False, + "shuffle": False, + }, + }, + ppsci.loss.FunctionalLoss(train_loss_func), + {"action": lambda out: out["action"]}, + metric={"metric": ppsci.metric.FunctionalMetric(metric_eval)}, + name="sup_valid", + ) + + # Initialize optimizer and solver + optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model) + solver = ppsci.solver.Solver( + model, + {sup_constraint.name: sup_constraint}, + cfg.output_dir, + optimizer, + None, + cfg.TRAIN.epochs, + cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + log_freq=cfg.log_freq, + validator={sup_validator.name: sup_validator}, + ) + + solver.train() + solver.eval() + + +class HumanoidEvaluator: + def __init__(self, model_path, num_episodes=5, episode_length=1000): + self.env = suite.load(domain_name="humanoid", task_name="run") + self.model_path = model_path + self.num_episodes = num_episodes + self.episode_length = episode_length + self.load_model() + + def load_model(self): + time_step = self.env.reset() + state_size = sum(v.size for v in time_step.observation.values()) + action_spec = self.env.action_spec() + action_size = action_spec.shape[0] + + self.model = HumanoidController(state_size, action_size) + state_dict = paddle.load(self.model_path) + self.model.set_state_dict(state_dict) + self.model.eval() + + def _flatten_observation(self, observation): + return np.concatenate([v.flatten() for v in observation.values()]) + + def evaluate_episode(self, record_video=False): + """Evaluate single episode and collect detailed data""" + import mujoco + + time_step = self.env.reset() + total_reward = 0 + frames = [] + + # Data collection lists + episode_data = {"rewards": [], "actions": [], "com_velocity": []} + + # Setup offscreen renderer if recording video + if record_video: + width, height = 640, 480 + renderer = mujoco.Renderer(self.env.physics.model, width, height) + + for t in range(self.episode_length): + # Get state and predict action + state = self._flatten_observation(time_step.observation) + state_tensor = {"state": paddle.to_tensor(state[None, :], dtype="float32")} + + with paddle.no_grad(): + action = self.model(state_tensor)["action"].numpy()[0] + + # Render frame if recording + if record_video: + renderer.update_scene(self.env.physics.data) + pixels = renderer.render() + frames.append(pixels) + + # Take step + time_step = self.env.step(action) + reward = time_step.reward if time_step.reward is not None else 0 + + # Collect data + episode_data["rewards"].append(reward) + episode_data["actions"].append(action) + episode_data["com_velocity"].append(time_step.observation["velocity"]) + + total_reward += reward + + if time_step.last(): + break + + # Convert lists to numpy arrays + episode_data["rewards"] = np.array(episode_data["rewards"]) + episode_data["actions"] = np.array(episode_data["actions"]) + episode_data["com_velocity"] = np.array(episode_data["com_velocity"]) + + return total_reward, frames, episode_data + + def evaluate(self, save_dir="./evaluation_results"): + """Run full evaluation with multiple episodes and generate analysis""" + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + rewards = [] + all_episode_data = [] + logger.info("\nStarting evaluation...") + + for ep in range(self.num_episodes): + logger.info(f"\nEpisode {ep + 1}/{self.num_episodes}") + # Record video for first and last episodes + 
record_video = ep == 0 or ep == self.num_episodes - 1 + reward, frames, episode_data = self.evaluate_episode(record_video) + rewards.append(reward) + all_episode_data.append(episode_data) + logger.info(f"Episode reward: {reward:.2f}") + + # Save video if frames were recorded + if record_video and frames: + import imageio + + video_path = save_dir / f"episode_{ep+1}.mp4" + imageio.mimsave(video_path, frames, fps=30) + logger.info(f"Saved video to {video_path}") + + # Generate analysis and save statistics + self._generate_analysis(rewards, all_episode_data, save_dir) + self._save_statistics(rewards, all_episode_data, save_dir) + + logger.info("\nEvaluation completed!") + logger.info(f"Mean reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}") + return rewards + + def _generate_analysis(self, rewards, all_episode_data, save_dir): + """Generate comprehensive analysis plots""" + plt.figure(figsize=(15, 10)) + + plt.subplot(2, 2, 1) + plt.plot(rewards, "b-o") + plt.title("Episode Rewards") + plt.xlabel("Episode") + plt.ylabel("Total Reward") + plt.grid(True) + + plt.subplot(2, 2, 2) + plt.hist(rewards, bins=min(len(rewards), 10), color="blue", alpha=0.7) + plt.axvline(np.mean(rewards), color="r", linestyle="--", label="Mean") + plt.title("Reward Distribution") + plt.xlabel("Reward") + plt.ylabel("Frequency") + plt.legend() + + plt.subplot(2, 2, 3) + for i in range(min(3, len(all_episode_data))): + plt.plot(all_episode_data[i]["rewards"], label=f"Episode {i+1}", alpha=0.7) + plt.title("Reward Trajectories") + plt.xlabel("Step") + plt.ylabel("Reward") + plt.legend() + plt.grid(True) + + plt.subplot(2, 2, 4) + for i in range(min(3, len(all_episode_data))): + vel = all_episode_data[i]["com_velocity"] + speed = np.linalg.norm(vel, axis=1) if len(vel.shape) > 1 else np.abs(vel) + plt.plot(speed, label=f"Episode {i+1}", alpha=0.7) + plt.title("Center of Mass Speed") + plt.xlabel("Step") + plt.ylabel("Speed") + plt.legend() + plt.grid(True) + + plt.tight_layout() + plt.savefig(save_dir / "performance_analysis.png") + plt.close() + + def _save_statistics(self, rewards, all_episode_data, save_dir): + """Save detailed statistical analysis""" + with open(save_dir / "detailed_stats.txt", "w") as f: + f.write("=== Humanoid Evaluation Statistics ===\n\n") + + # Episode Statistics + f.write("Episode Statistics:\n") + f.write(f"Number of episodes: {len(rewards)}\n") + f.write(f"Mean reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\n") + f.write(f"Max reward: {np.max(rewards):.2f}\n") + f.write(f"Min reward: {np.min(rewards):.2f}\n\n") + + # Action Statistics + all_actions = np.concatenate( + [ep_data["actions"] for ep_data in all_episode_data] + ) + f.write("Action Statistics:\n") + f.write(f"Mean action magnitude: {np.mean(np.abs(all_actions)):.3f}\n") + f.write(f"Max action magnitude: {np.max(np.abs(all_actions)):.3f}\n") + f.write(f"Action std: {np.std(all_actions):.3f}\n\n") + + # Movement Statistics + f.write("Movement Statistics:\n") + for ep_idx, ep_data in enumerate(all_episode_data[:3]): # First 3 episodes + velocities = ep_data["com_velocity"] + speed = ( + np.linalg.norm(velocities, axis=1) + if len(velocities.shape) > 1 + else np.abs(velocities) + ) + f.write(f"Episode {ep_idx + 1}:\n") + f.write(f" Mean speed: {np.mean(speed):.3f}\n") + f.write(f" Max speed: {np.max(speed):.3f}\n") + f.write(f" Distance covered: {np.sum(speed):.3f}\n\n") + + +def evaluate(cfg: DictConfig): + """Evaluate trained humanoid controller""" + # Initialize evaluator with trained model + evaluator = 
HumanoidEvaluator(
+ model_path=cfg.EVAL.pretrained_model_path,
+ num_episodes=cfg.EVAL.num_episodes,
+ episode_length=cfg.EVAL.episode_length,
+ )
+
+ # Create evaluation output directory
+ eval_dir = Path(cfg.output_dir) / "evaluation_results"
+ eval_dir.mkdir(parents=True, exist_ok=True)
+
+ # Run evaluation
+ logger.info("Starting evaluation...")
+ rewards = evaluator.evaluate(
+ save_dir=eval_dir
+ )
+
+ return rewards
+
+
+@hydra.main(
+ version_base=None, config_path="./conf", config_name="humanoid_control.yaml"
+)
+def main(cfg: DictConfig):
+ if cfg.mode == "train":
+ train(cfg)
+ elif cfg.mode == "eval":
+ evaluate(cfg)
+ else:
+ raise ValueError(
+ f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ppsci/metric/anomaly_coef.py b/ppsci/metric/anomaly_coef.py index 4faf2ac9b1..a94af65edc 100644 --- a/ppsci/metric/anomaly_coef.py +++ b/ppsci/metric/anomaly_coef.py @@ -32,7 +32,7 @@ class LatitudeWeightedACC(base.Metric): $$ $$
- L_m = N_{lat}\dfrac{cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}cos(lat_j)}
+ L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)}
$$ $lat_m$ is the latitude at m.
diff --git a/ppsci/metric/l2_rel.py b/ppsci/metric/l2_rel.py index 501b7f9de9..80fe68a13e 100644 --- a/ppsci/metric/l2_rel.py +++ b/ppsci/metric/l2_rel.py @@ -21,7 +21,11 @@ class L2Rel(base.Metric): r"""Class for l2 relative error. $$
- metric = \dfrac{\Vert x-y \Vert_2}{\Vert y \Vert_2}
+ metric = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\Vert \mathbf{y} \Vert_2}
+ $$
+
+ $$
+ \mathbf{x}, \mathbf{y} \in \mathbb{R}^{N}
$$ Args:
diff --git a/ppsci/metric/mae.py b/ppsci/metric/mae.py index 54f814aa1f..cd8e28f9ba 100644 --- a/ppsci/metric/mae.py +++ b/ppsci/metric/mae.py @@ -22,7 +22,11 @@ class MAE(base.Metric): r"""Mean absolute error. $$
- metric = \dfrac{1}{N}\sum\limits_{i=1}^{N}{|x_i-y_i|}
+ metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_1
+ $$
+
+ $$
+ \mathbf{x}, \mathbf{y} \in \mathbb{R}^{N}
$$ Args:
diff --git a/ppsci/metric/mse.py b/ppsci/metric/mse.py index 979ca3c207..b0a1ae6ad4 100644 --- a/ppsci/metric/mse.py +++ b/ppsci/metric/mse.py @@ -22,7 +22,11 @@ class MSE(base.Metric): r"""Mean square error $$
- metric = \dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}
+ metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_2^2
+ $$
+
+ $$
+ \mathbf{x}, \mathbf{y} \in \mathbb{R}^{N}
$$ Args:
diff --git a/ppsci/metric/rmse.py b/ppsci/metric/rmse.py index 7dff093432..8b2c7a7b08 100644 --- a/ppsci/metric/rmse.py +++ b/ppsci/metric/rmse.py @@ -28,7 +28,11 @@ class RMSE(base.Metric): r"""Root mean square error $$
- metric = \sqrt{\dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}}
+ metric = \sqrt{\dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_2^2}
+ $$
+
+ $$
+ \mathbf{x}, \mathbf{y} \in \mathbb{R}^{N}
$$ Args:
@@ -62,7 +66,7 @@ class LatitudeWeightedRMSE(base.Metric): $$ $$
- L_m = N_{lat}\dfrac{cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}cos(lat_j)}
+ L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)}
$$ $lat_m$ is the latitude at m.
diff --git a/ppsci/utils/expression.py b/ppsci/utils/expression.py index 33116e35a6..80fe50b8ef 100644 --- a/ppsci/utils/expression.py +++ b/ppsci/utils/expression.py @@ -166,7 +166,7 @@ def visu_forward( Args: expr_dict (Optional[Dict[str, Callable]]): Expression dict.
- input_dict (Dict[str, paddle.Tensor]]): Input dict.
+ input_dict (Dict[str, paddle.Tensor]): Input dict.
model (nn.Layer): NN model.
Returns:
diff --git a/ppsci/visualize/plot.py b/ppsci/visualize/plot.py index 7aa90d91ad..827847691e 100644 --- a/ppsci/visualize/plot.py +++ b/ppsci/visualize/plot.py @@ -165,7 +165,7 @@ def _save_plot_from_2d_array( Args: filename (str): Filename. visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
- visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
+ visu_keys (Tuple[str, ...]): Keys for visualizing data, such as ("u", "v").
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1. stride (int, optional): The time stride of visualization. Defaults to 1. xticks (Optional[Tuple[float, ...]]): Tuple of xtick locations. Defaults to None.
@@ -314,7 +314,7 @@ def _save_plot_from_3d_array( Args: filename (str): Filename. visu_data (Tuple[np.ndarray, ...]): Data that requires visualization.
- visu_keys (Tuple[str, ...]]): Keys for visualizing data. such as ("u", "v").
+ visu_keys (Tuple[str, ...]): Keys for visualizing data, such as ("u", "v").
num_timestamps (int, optional): Number of timestamps coord/value contains. Defaults to 1. """