diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md
index b1623cd20..5a373e921 100644
--- a/docs/src/SUMMARY.md
+++ b/docs/src/SUMMARY.md
@@ -11,6 +11,7 @@
 - [Simulator](simulator.md)
 - [Interactive scenario editor](scene-editor.md)
 - [Visualizer](visualizer.md)
+- [Export model to ONNX](export-onnx.md)
 
 # Data
 
diff --git a/docs/src/export-onnx.md b/docs/src/export-onnx.md
new file mode 100644
index 000000000..a08ffb05f
--- /dev/null
+++ b/docs/src/export-onnx.md
@@ -0,0 +1,81 @@
+# Exporting PufferDrive Models to ONNX
+
+PufferDrive provides a utility script to export trained PyTorch models to the ONNX format. This is useful for deployment, inference optimization, or using the model in environments that support ONNX Runtime.
+
+## Usage
+
+The export script is located at `scripts/export_onnx.py`. You can run it from the root of the repository.
+
+### Basic Usage
+
+To export a model using default settings (assuming you have a checkpoint at the default path or specify one):
+
+```bash
+python scripts/export_onnx.py --checkpoint path/to/your/checkpoint.pt
+```
+
+This will create an `.onnx` file in the same directory as the checkpoint, with the same name (e.g., `checkpoint.onnx`).
+
+### Specifying Output Path
+
+You can specify a custom output path for the ONNX file:
+
+```bash
+python scripts/export_onnx.py \
+    --checkpoint experiments/my_experiment/model_000100.pt \
+    --output models/my_model.onnx
+```
+
+### Specifying Environment
+
+If you are using a specific environment configuration, you can specify it with `--env`:
+
+```bash
+python scripts/export_onnx.py --env puffer_drive --checkpoint ...
+```
+
+### ONNX Opset Version
+
+You can specify the ONNX opset version (default is 18):
+
+```bash
+python scripts/export_onnx.py --opset 17 ...
+```
+
+## Arguments
+
+| Argument | Type | Default | Description |
+|----------|------|---------|-------------|
+| `--env` | str | `puffer_drive` | The environment name to load configuration for. |
+| `--checkpoint` | str | (defaults to an example experiment path) | Path to the PyTorch `.pt` checkpoint file. |
+| `--output` | str | `None` (derived from checkpoint) | Path where the `.onnx` file will be saved. |
+| `--opset` | int | `18` | ONNX opset version to use for export. |
+
+## Verification
+
+The script automatically verifies the exported ONNX model by running a forward pass on both the PyTorch model and the ONNX model with dummy inputs and comparing the outputs. It checks for:
+- Logits
+- Value
+- LSTM hidden states (if applicable)
+
+If verification passes, it will print match confirmations. If there are mismatches, it will raise an error or print a mismatch warning.
+
+## Exporting Model Weights to .bin
+
+You can also export the model weights to a binary format (`.bin`) which can be loaded by the C backend of PufferDrive. This is done using `scripts/export_model_bin.py`.
+
+### Usage
+
+```bash
+python scripts/export_model_bin.py --checkpoint path/to/your/checkpoint.pt
+```
+
+### Arguments
+
+| Argument | Type | Default | Description |
+|----------|------|---------|-------------|
+| `--env` | str | `puffer_drive` | The environment name to load configuration for. |
+| `--checkpoint` | str | (required) | Path to the PyTorch `.pt` checkpoint file. |
+| `--output` | str | `pufferlib/resources/drive/model_puffer_drive_000100.bin` | Output path for the binary weights file. |
+
+This script flattens all model parameters into a single contiguous binary file.
diff --git a/scripts/export_model_bin.py b/scripts/export_model_bin.py
new file mode 100644
index 000000000..5abfbffee
--- /dev/null
+++ b/scripts/export_model_bin.py
@@ -0,0 +1,134 @@
+import argparse
+import os
+import torch
+import importlib
+import numpy as np
+
+import pufferlib.utils
+import pufferlib.vector
+import pufferlib.models
+
+from pufferlib.ocean.torch import Drive
+
+
+def load_config(env_name, config_dir=None):
+    # Minimal config loader based on pufferl.py
+    import configparser
+    import glob
+    from collections import defaultdict
+    import ast
+
+    if config_dir is None:
+        puffer_dir = os.path.dirname(os.path.realpath(pufferlib.__file__))
+    else:
+        puffer_dir = config_dir
+
+    puffer_config_dir = os.path.join(puffer_dir, "config/**/*.ini")
+    puffer_default_config = os.path.join(puffer_dir, "config/default.ini")
+
+    found = False
+    for path in glob.glob(puffer_config_dir, recursive=True):
+        p = configparser.ConfigParser()
+        p.read([puffer_default_config, path])
+        if env_name in p["base"]["env_name"].split():
+            found = True
+            break
+
+    if not found:
+        raise ValueError(f"No config for env_name {env_name}")
+
+    def puffer_type(value):
+        try:
+            return ast.literal_eval(value)
+        except (ValueError, SyntaxError):
+            return value
+
+    args = defaultdict(dict)
+    for section in p.sections():
+        for key in p[section]:
+            value = puffer_type(p[section][key])
+            args[section][key] = value
+
+    return args
+
+
+# Export PufferDrive model weights to .bin
+def export_weights():
+    parser = argparse.ArgumentParser(description="Export PufferDrive model weights to .bin")
+    parser.add_argument("--env", type=str, default="puffer_drive", help="Environment name")
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        help="Path to .pt checkpoint",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="pufferlib/resources/drive/model_puffer_drive_000100.bin",
+        help="Output .bin file path",
+    )
+
+    args = parser.parse_args()
+
+    # Load configuration
+    config = load_config(args.env)
+
+    # Load environment to get observation/action spaces
+    package = config["base"]["package"]
+    module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}"
+    env_module = importlib.import_module(module_name)
+    make_env = env_module.env_creator(args.env)
+
+    # Use valid dummy env to initialize policy
+    # Ensure env args/kwargs are correctly passed as expected by make()
+    env_kwargs = config["env"]
+
+    vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1)
+
+    # Initialize Policy
+    print("Initializing Policy...")
+    policy = Drive(vecenv.driver_env, **config["policy"])
+
+    if config["base"]["rnn_name"]:
+        print("Wrapping with LSTM...")
+        policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"])
+
+    # Load Checkpoint
+    print(f"Loading checkpoint from {args.checkpoint}...")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+
+    # Handle both full checkpoint dict and raw state dict
+    if isinstance(checkpoint, dict) and "agent_state_dict" in checkpoint:
+        state_dict = checkpoint["agent_state_dict"]
+    else:
+        state_dict = checkpoint
+
+    # Strip compile prefixes
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        if k.startswith("_orig_mod."):
+            new_state_dict[k[10:]] = v
+        else:
+            new_state_dict[k] = v
+
+    policy.load_state_dict(new_state_dict)
+    policy.eval()
+
+    # Export Weights
+    print(f"Exporting weights to {args.output}...")
+    weights = []
+    total_params = 0
+    for name, param in policy.named_parameters():
+        param_flat = param.data.cpu().numpy().flatten()
+        weights.append(param_flat)
+        count = param_flat.size
+        print(f"  {name}: {param.shape} -> {count} params")
+        total_params += count
+
+    weights = np.concatenate(weights)
+    weights.tofile(args.output)
+    print(f"Success! Saved {len(weights)} weights ({total_params} params) to {args.output}")
+
+
+if __name__ == "__main__":
+    export_weights()
diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py
new file mode 100644
index 000000000..c6cfb3777
--- /dev/null
+++ b/scripts/export_onnx.py
@@ -0,0 +1,251 @@
+import argparse
+import os
+import torch
+import importlib
+import numpy as np
+import onnxruntime as ort
+
+import pufferlib.utils
+import pufferlib.vector
+import pufferlib.models
+
+from pufferlib.ocean.torch import Drive
+from scripts.export_model_bin import load_config
+
+
+class OnnxWrapper(torch.nn.Module):
+    def __init__(self, policy):
+        super().__init__()
+        self.policy = policy
+
+    def forward(self, observation, h, c):
+        # Reconstruct the state dictionary expected by LSTMWrapper
+        # state must be mutable as forward_eval updates it
+        state = {"lstm_h": h, "lstm_c": c}
+
+        # Call forward_eval
+        logits, value = self.policy.forward_eval(observation, state)
+
+        # Extract updated states
+        new_h = state["lstm_h"]
+        new_c = state["lstm_c"]
+
+        return logits, value, new_h, new_c
+
+
+def export_to_onnx(verify=True):
+    parser = argparse.ArgumentParser(description="Export PufferDrive model to ONNX")
+    parser.add_argument("--env", type=str, default="puffer_drive", help="Environment name")
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        default="experiments/puffer_drive_73kbtsi5/model_puffer_drive_000200.pt",
+        help="Path to .pt checkpoint",
+    )
+    parser.add_argument("--output", type=str, help="Output .onnx file path")
+    parser.add_argument("--opset", type=int, default=18, help="ONNX opset version")
+
+    args = parser.parse_args()
+
+    # Load configuration
+    config = load_config(args.env)
+
+    # Load environment to get observation/action spaces
+    package = config["base"]["package"]
+    module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}"
+    env_module = importlib.import_module(module_name)
+    make_env = env_module.env_creator(args.env)
+
+    # Ensure env args/kwargs are correctly passed
+    env_kwargs = config["env"]
+
+    vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1)
+
+    # Initialize Policy
+    print("Initializing Policy...")
+    policy = Drive(vecenv.driver_env, **config["policy"])
+
+    if config["base"]["rnn_name"]:
+        print("Wrapping with LSTM...")
+        policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"])
+
+    # Load Checkpoint
+    print(f"Loading checkpoint from {args.checkpoint}...")
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+
+    # Handle both full checkpoint dict and raw state dict
+    if isinstance(checkpoint, dict) and "agent_state_dict" in checkpoint:
+        state_dict = checkpoint["agent_state_dict"]
+    else:
+        state_dict = checkpoint
+
+    # Strip compile prefixes
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        if k.startswith("_orig_mod."):
+            new_state_dict[k[10:]] = v
+        else:
+            new_state_dict[k] = v
+
+    policy.load_state_dict(new_state_dict)
+    policy.eval()
+
+    # Prepare inputs for ONNX export
+    print("Preparing sample inputs...")
+    batch_size = 1
+
+    obs_space = vecenv.single_observation_space
+    # Flatten observation if needed, Drive policy handles flattening internally usually but check vecenv
+    # The LSTMWrapper expects (B, ObsDim)
+    obs_dim = np.prod(obs_space.shape)
+
+    # Create Dummy Observation
+    if config["base"]["rnn_name"]:
+        # If wrapped, access the internal Drive policy
+        drive_policy = policy.policy
+    else:
+        drive_policy = policy
+
+    if hasattr(drive_policy, "ego_dim"):
+        # Construct valid dummy observation for Drive policy
+        # Retrieve needed dimensions
+        ego_dim = drive_policy.ego_dim
+        max_partner_objects = drive_policy.max_partner_objects
+        partner_features = drive_policy.partner_features
+        max_road_objects = drive_policy.max_road_objects
+        road_features = drive_policy.road_features
+
+        partner_dim = max_partner_objects * partner_features
+        road_dim = max_road_objects * road_features
+
+        # Random parts
+        dummy_ego = torch.randn(batch_size, ego_dim)
+        dummy_partner = torch.randn(batch_size, partner_dim)
+
+        # Road part: continuous features + categorical feature
+        road_cont_dim = road_features - 1
+
+        # (Batch, MaxObjects, Feats-1)
+        dummy_road_cont = torch.randn(batch_size, max_road_objects, road_cont_dim)
+
+        # (Batch, MaxObjects, 1) - valid categorical values [0, 6]
+        # Ensure it's 0-6 range. 7 is num_classes.
+        dummy_road_cat = torch.randint(0, 7, (batch_size, max_road_objects, 1)).float()
+
+        # Concatenate and flatten
+        dummy_road_objs = torch.cat([dummy_road_cont, dummy_road_cat], dim=2)
+        dummy_road = dummy_road_objs.view(batch_size, -1)
+
+        dummy_obs = torch.cat([dummy_ego, dummy_partner, dummy_road], dim=1)
+    else:
+        print("Warning: Could not determine Drive policy structure. Using random observation.")
+        dummy_obs = torch.randn(batch_size, obs_dim)
+
+    # Dummy LSTM States
+    hidden_size = config["rnn"]["hidden_size"]
+    # LSTMCell expects (Batch, Hidden) not (NumLayers, Batch, Hidden)
+    dummy_h = torch.zeros(batch_size, hidden_size)
+    dummy_c = torch.zeros(batch_size, hidden_size)
+
+    # Wrap policy for export
+    onnx_policy = OnnxWrapper(policy)
+    onnx_policy.eval()
+
+    # Determine output path
+    if not args.output:
+        args.output = os.path.splitext(args.checkpoint)[0] + ".onnx"
+    # Ensure output directory exists
+    output_dir = os.path.dirname(args.output)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Exporting to {args.output}...")
+
+    # Dynamic axes for batch size flexibility
+    dynamic_axes = {
+        "observation": {0: "batch_size"},
+        "lstm_h_in": {0: "batch_size"},
+        "lstm_c_in": {0: "batch_size"},
+        "logits": {0: "batch_size"},
+        "value": {0: "batch_size"},
+        "lstm_h_out": {0: "batch_size"},
+        "lstm_c_out": {0: "batch_size"},
+    }
+
+    dummy_inputs = (dummy_obs, dummy_h, dummy_c)
+    torch.onnx.export(
+        onnx_policy,
+        dummy_inputs,
+        args.output,
+        export_params=True,
+        opset_version=args.opset,
+        do_constant_folding=True,
+        input_names=["observation", "lstm_h_in", "lstm_c_in"],
+        output_names=["logits", "value", "lstm_h_out", "lstm_c_out"],
+        dynamic_axes=dynamic_axes,
+    )
+
+    print("Export complete!")
+    print("\nSample Inputs shapes:")
+    print(f"Observation: {dummy_obs.shape}")
+    print(f"LSTM h: {dummy_h.shape}")
+    print(f"LSTM c: {dummy_c.shape}")
+
+    # Verify ONNX model
+    if verify:
+        print("\nVerifying ONNX model...")
+        sess_options = ort.SessionOptions()
+        sess_options.intra_op_num_threads = 1
+        sess_options.inter_op_num_threads = 1
+        ort_session = ort.InferenceSession(args.output, sess_options)
+
+        # PyTorch output
+        with torch.no_grad():
+            torch_logits, torch_value, torch_h, torch_c = onnx_policy(*dummy_inputs)
+
+        # Output .pt files for testing
+        print(f"Saving test inputs/outputs to {output_dir}")
+        torch.save(dummy_inputs, os.path.join(output_dir, "test_inputs.pt"))
+        torch.save((torch_logits, torch_value, torch_h, torch_c), os.path.join(output_dir, "test_outputs.pt"))
+
+        # ONNX Runtime output
+        ort_inputs = {"observation": dummy_obs.numpy(), "lstm_h_in": dummy_h.numpy(), "lstm_c_in": dummy_c.numpy()}
+        ort_outs = ort_session.run(None, ort_inputs)
+
+        # Compare outputs
+        def compare(name, torch_out, ort_out, atol=1e-5):
+            if isinstance(torch_out, tuple):
+                for i, (t_out, o_out) in enumerate(zip(torch_out, ort_out)):
+                    compare(f"{name}_{i}", t_out, o_out, atol)
+                return
+
+            try:
+                np.testing.assert_allclose(torch_out.detach().numpy(), ort_out, rtol=1e-03, atol=atol)
+                print(f"✔ {name} match")
+            except AssertionError as e:
+                print(f"✘ {name} mismatch")
+                print(e)
+
+        # Unpack ONNX outputs if logits was a tuple
+        if isinstance(torch_logits, tuple):
+            num_logits = len(torch_logits)
+            ort_logits = ort_outs[:num_logits]
+            ort_value = ort_outs[num_logits]
+            ort_h = ort_outs[num_logits + 1]
+            ort_c = ort_outs[num_logits + 2]
+        else:
+            ort_logits = ort_outs[0]
+            ort_value = ort_outs[1]
+            ort_h = ort_outs[2]
+            ort_c = ort_outs[3]
+
+        compare("Logits", torch_logits, ort_logits)
+        compare("Value", torch_value, ort_value)
+        compare("LSTM h", torch_h, ort_h)
+        compare("LSTM c", torch_c, ort_c)
+
+    # Export example input and output to .pt files
+
+
+if __name__ == "__main__":
+    export_to_onnx(verify=True)
diff --git a/scripts/verify_onnx.py b/scripts/verify_onnx.py
new file mode 100644
index 000000000..27dc9ee9d
--- /dev/null
+++ b/scripts/verify_onnx.py
@@ -0,0 +1,70 @@
+import argparse
+import numpy as np
+import onnx
+import onnxruntime as ort
+
+
+def verify_onnx():
+    parser = argparse.ArgumentParser(description="Verify ONNX model")
+    parser.add_argument("--model", type=str, required=True, help="Path to .onnx file")
+    args = parser.parse_args()
+
+    print(f"Checking model: {args.model}")
+
+    # 1. Verify Model Structure with ONNX
+    try:
+        model = onnx.load(args.model)
+        onnx.checker.check_model(model)
+        print("ONNX Model Check: Passed (Structure is valid)")
+    except Exception as e:
+        print(f"ONNX Model Check: Failed\n{e}")
+        return
+
+    # 2. Verify Execution with ONNX Runtime
+    try:
+        print("Starting ONNX Runtime session...")
+        ort_session = ort.InferenceSession(args.model)
+
+        print("Session created successfully")
+
+        # Prepare Dummy Inputs based on model expectation
+        inputs = {}
+        batch_size = 1
+
+        for input_meta in ort_session.get_inputs():
+            name = input_meta.name
+            shape = input_meta.shape
+            dtype = input_meta.type
+
+            # Handle dynamic axes (often represented as strings or -1)
+            processed_shape = []
+            for dim in shape:
+                if isinstance(dim, str) or dim is None or dim < 0:
+                    processed_shape.append(batch_size)
+                else:
+                    processed_shape.append(dim)
+
+            print(f"  Input: {name}, Shape: {shape} -> Using: {processed_shape}, Type: {dtype}")
+
+            # Create random input data
+            if "float" in dtype:
+                data = np.random.randn(*processed_shape).astype(np.float32)
+            else:
+                data = np.zeros(processed_shape).astype(np.int64)  # Fallback
+
+            inputs[name] = data
+
+        # Run Inference
+        outputs = ort_session.run(None, inputs)
+
+        print("Inference Verification: Passed")
+        print("\nOutput Shapes:")
+        for output_meta, output_data in zip(ort_session.get_outputs(), outputs):
+            print(f"  Output: {output_meta.name}, Shape: {output_data.shape}")
+
+    except Exception as e:
+        print(f"Inference Verification: Failed\n{e}")
+
+
+if __name__ == "__main__":
+    verify_onnx()