-
Notifications
You must be signed in to change notification settings - Fork 1
This PR impements pi0.5 as a HF style wrapper #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
bbe7973
Implement Pi0.5 upgrade: new architecture with flow matching and FAST…
De-funkd 5cda929
wip backup before starting PI05 HF wrapper
De-funkd 96084f6
final commit
De-funkd a7757bf
removed the init file from root
De-funkd 2e47a85
fixed comments
De-funkd 71bc7da
Resolve merge conflict: reintegrate pi05 registry entry
De-funkd 13f65fa
removed redundant test files
De-funkd 1358953
integration fixes for pi05
Refinath 97c49f7
Fix Pi0.5 contract mismatches to align with Ark training and rollout …
De-funkd bd86766
Resolve merge conflict in registry.py by including both pi05 and Pi05…
De-funkd b504172
fixed rollout issues
De-funkd 817f963
fixes to lang tokens
De-funkd c684eae
fixes to training and rollouts
De-funkd e00c4a3
implemented fixes
De-funkd 0c65b93
more fixes
De-funkd d3771f0
pr fixes
Refinath a831e27
pr issue fixes
Refinath a6f0575
dataset fixes
De-funkd 4554b6f
pi05 dataset updated based on existing structure
Refinath d1ed44d
toekns and attension mask for lerobot
Refinath 1c6e4f6
PR fixes, roll out and training
Refinath File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| # Pi0.5 Implementation | ||
|
|
||
| This directory contains the complete Pi0.5 implementation following the HuggingFace wrapper pattern for the Ark ML framework. | ||
|
|
||
| ## Architecture Overview | ||
|
|
||
| Pi0.5 is an advanced Vision-Language-Action model that implements: | ||
| - **Multi-stage training**: Pretraining (CE(text) + CE(FAST tokens)) and Post-training (CE(subtask) + α × flow_matching_loss) | ||
| - **Flow matching**: For precise action prediction using vector field networks | ||
| - **Multiple prediction heads**: Subtask, FAST, and flow heads | ||
| - **Enhanced backbone**: Support for SigLIP-Gemma vision-language architecture | ||
|
|
||
| ## Directory Structure | ||
|
|
||
| ``` | ||
| pi05/ | ||
| ├── models.py # Core Pi0.5 policy (HuggingFace wrapper) | ||
| ├── algorithm.py # Training algorithm | ||
| ├── trainer.py # Multi-stage trainer | ||
| ├── evaluator.py # Evaluation metrics | ||
| ├── dataset.py # Multi-modality dataset | ||
| ├── config_utils.py # Configuration utilities | ||
| ├── compute_stats.py # Statistics computation | ||
| ├── utils.py # Utility functions | ||
| └── README.md # This file | ||
| ``` | ||
|
|
||
| ## Usage Instructions | ||
|
|
||
| ### 1. Loading a Pre-trained Model | ||
|
|
||
| ```python | ||
| from arkml.algos.vla.pi05.models import Pi05Policy | ||
|
|
||
| # Load from Hugging Face Hub or local path | ||
| policy = Pi05Policy( | ||
| policy_type='pi0.5', | ||
| model_path='your-huggingface-username/pi05-model', # or local path | ||
| backbone_type='siglip_gemma', # Vision-language backbone | ||
| use_fast_tokens=True, # Enable FAST tokenization | ||
| use_flow_matching=True, # Enable flow matching | ||
| obs_dim=9, # Observation dimension | ||
| action_dim=8, # Action dimension | ||
| image_dim=(3, 480, 640), # Image dimensions (C, H, W) | ||
| pred_horizon=1 # Prediction horizon | ||
| ) | ||
|
|
||
| # Move to device | ||
| policy = policy.to_device('cuda') | ||
| ``` | ||
|
|
||
| ### 2. Making Predictions | ||
|
|
||
| ```python | ||
| import torch | ||
|
|
||
| # Prepare observation dictionary | ||
| observation = { | ||
| 'image': torch.randn(1, 3, 224, 224), # Image tensor | ||
| 'state': torch.randn(9), # State vector | ||
| 'task': 'pick up the red block' # Task instruction (optional) | ||
| } | ||
|
|
||
| # Get action prediction | ||
| action = policy.predict(observation) | ||
| print(f"Predicted action: {action}") | ||
| ``` | ||
|
|
||
| ### 3. Training a New Model | ||
|
|
||
| ```python | ||
| from arkml.algos.vla.pi05.algorithm import Pi05Algorithm | ||
| from arkml.algos.vla.pi05.dataset import create_pi05_dataloader | ||
| from omegaconf import DictConfig | ||
|
|
||
| # Create your dataset and dataloader | ||
| train_dataloader = create_pi05_dataloader( | ||
| dataset_path='path/to/your/dataset', | ||
| batch_size=8, | ||
| shuffle=True | ||
| ) | ||
|
|
||
| # Load your policy | ||
| policy = Pi05Policy( | ||
| policy_type='pi0.5', | ||
| model_path='path/to/pretrained/model', # Or use a base model | ||
| # ... other parameters | ||
| ) | ||
|
|
||
| # Configure training | ||
| config = DictConfig({ | ||
| 'trainer': { | ||
| 'lr': 2e-4, | ||
| 'batch_size': 8, | ||
| 'max_epochs': 10, | ||
| 'weight_decay': 0.01, | ||
| 'num_workers': 4, | ||
| 'use_bf16': True | ||
| }, | ||
| 'training': { | ||
| 'stage': 'pretrain', # 'pretrain' or 'posttrain' | ||
| 'flow_alpha': 10.0, # Weight for flow matching loss | ||
| 'pretrain_steps': 280000, # Steps for pretraining | ||
| 'posttrain_steps': 80000 # Steps for post-training | ||
| } | ||
| }) | ||
|
|
||
| # Create algorithm and train | ||
| algorithm = Pi05Algorithm(policy=policy, device='cuda', cfg=config) | ||
| results = algorithm.train(train_dataset=your_train_dataset) | ||
| ``` | ||
|
|
||
| ### 4. Configuration Options | ||
|
|
||
| Key configuration parameters: | ||
|
|
||
| - `backbone_type`: Vision-language backbone ('siglip_gemma', etc.) | ||
| - `use_fast_tokens`: Whether to use FAST tokenization for action discretization | ||
| - `use_flow_matching`: Whether to use flow matching for action prediction | ||
| - `training_stage`: 'pretrain' or 'posttrain' for multi-stage training | ||
| - `flow_alpha`: Weight for flow matching loss (default: 10.0) | ||
|
|
||
| ## Training Stages | ||
|
|
||
| Pi0.5 supports multi-stage training: | ||
|
|
||
| ### Pretraining Stage | ||
| ``` | ||
| CE(text) + CE(FAST tokens) | ||
| ``` | ||
| - Focuses on learning foundational representations | ||
| - Uses multiple modalities and FAST tokenization | ||
|
|
||
| ### Post-training Stage | ||
| ``` | ||
| CE(subtask) + α × flow_matching_loss | ||
| ``` | ||
| - Refines the model with flow matching and subtask prediction | ||
| - Enables precise action prediction using flow matching | ||
|
|
||
| ## Evaluation Metrics | ||
|
|
||
| The evaluator provides comprehensive metrics: | ||
| - Action MSE and MAE | ||
| - Accuracy within threshold | ||
| - Subtask prediction accuracy | ||
| - Multi-modality evaluation | ||
|
|
||
| ## Integration with LeRobot | ||
|
|
||
| This implementation uses the LeRobot Pi0.5 policy under the hood: | ||
| - Follows LeRobot's model architecture | ||
| - Compatible with LeRobot datasets and tools | ||
| - Supports LeRobot's training and evaluation pipelines | ||
|
|
||
| ## Example Usage Script | ||
|
|
||
| For a complete example, see the example script that demonstrates: | ||
| - Model loading | ||
| - Training setup | ||
| - Prediction workflow | ||
| - Evaluation process | ||
|
|
||
| ## Requirements | ||
|
|
||
| - LeRobot >= 0.4.3 | ||
| - Transformers | ||
| - PyTorch >= 1.12 | ||
| - Compatible with ark_ml framework | ||
|
|
||
| ## Testing | ||
|
|
||
| Run tests to verify functionality: | ||
| ```bash | ||
| python -m pytest tests_and_benchmarks/pi05_tests/ | ||
| ``` | ||
|
|
||
| ## Benchmarks | ||
|
|
||
| Run performance benchmarks: | ||
| ```bash | ||
| python tests_and_benchmarks/pi05_benchmarks/benchmark_pi05.py | ||
| ``` | ||
|
|
||
| ## Notes | ||
|
|
||
| - This implementation follows the same pattern as PiZero for consistency | ||
| - Multi-stage training requires different dataset configurations for each stage | ||
| - Flow matching is particularly effective for precise manipulation tasks | ||
| - FAST tokenization enables efficient action discretization during pretraining |
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| from typing import Any | ||
| import torch | ||
| from torch.utils.data import DataLoader | ||
| from arkml.core.algorithm import BaseAlgorithm | ||
| from arkml.core.policy import BasePolicy | ||
| from arkml.core.registry import ALGOS | ||
| from arkml.algos.vla.pi05.trainer import Pi05Trainer | ||
| from arkml.algos.vla.pi05.evaluator import Pi05Evaluator | ||
| from omegaconf import DictConfig | ||
|
|
||
| @ALGOS.register("pi05") | ||
| class Pi05Algorithm(BaseAlgorithm): | ||
| """ | ||
| Algorithm wrapper for Pi0.5 training and evaluation. | ||
| Implements the complete training pipeline for Pi0.5 with multi-stage training. | ||
| """ | ||
|
|
||
| def __init__(self, policy: BasePolicy, device: str, cfg: DictConfig) -> None: | ||
| self.policy = policy | ||
| self.device = device | ||
| self.cfg = cfg | ||
|
|
||
| # Extract training configuration | ||
| self.lr = cfg.trainer.get('lr', 2e-4) | ||
| self.batch_size = cfg.trainer.get('batch_size', 8) | ||
| self.max_epochs = cfg.trainer.get('max_epochs', 10) | ||
| self.weight_decay = cfg.trainer.get('weight_decay', 0.0) | ||
| self.num_workers = cfg.trainer.get('num_workers', 4) | ||
| self.use_bf16 = cfg.trainer.get('use_bf16', True) | ||
|
|
||
| # Training-specific config | ||
| self.training_stage = cfg.training.get('stage', 'pretrain') | ||
| self.flow_alpha = cfg.training.get('flow_alpha', 10.0) | ||
| self.pretrain_steps = cfg.training.get('pretrain_steps', 280000) | ||
| self.posttrain_steps = cfg.training.get('posttrain_steps', 80000) | ||
| self.integration_steps = cfg.training.get('integration_steps', 10) | ||
|
|
||
| def train(self, train_dataset, val_dataset=None) -> Any: | ||
| """ | ||
| Train the Pi0.5 model with multi-stage approach. | ||
| """ | ||
| # Create data loaders | ||
| train_dataloader = torch.utils.data.DataLoader( | ||
| train_dataset, | ||
| batch_size=self.batch_size, | ||
| shuffle=True, | ||
| num_workers=self.num_workers, | ||
| pin_memory=True | ||
| ) | ||
|
|
||
| val_dataloader = None | ||
| if val_dataset: | ||
| val_dataloader = torch.utils.data.DataLoader( | ||
| val_dataset, | ||
| batch_size=self.batch_size, | ||
| shuffle=False, | ||
| num_workers=self.num_workers, | ||
| pin_memory=True | ||
| ) | ||
|
|
||
| # Initialize trainer with config | ||
| trainer = Pi05Trainer( | ||
| model=self.policy, | ||
| dataloader=train_dataloader, | ||
| device=self.device, | ||
| lr=self.lr, | ||
| weight_decay=self.weight_decay, | ||
| num_epochs=self.max_epochs, | ||
| grad_accum=1.0, # Gradient accumulation | ||
| output_dir='./output', # TODO: Get from config | ||
| use_bf16=self.use_bf16, | ||
| flow_alpha=self.flow_alpha, | ||
| val_dataloader=val_dataloader, | ||
| eval_every=1 | ||
| ) | ||
|
|
||
| # Set the training stage on the model | ||
| self.policy.training_stage = self.training_stage | ||
|
|
||
| # Perform training based on stage | ||
| return trainer.fit() | ||
|
|
||
| def eval(self, eval_dataset) -> dict: | ||
| """ | ||
| Evaluate the Pi0.5 model performance. | ||
| """ | ||
| eval_dataloader = torch.utils.data.DataLoader( | ||
| eval_dataset, | ||
| batch_size=self.batch_size, | ||
| shuffle=False, | ||
| num_workers=self.num_workers, | ||
| pin_memory=True | ||
| ) | ||
|
|
||
| # Initialize evaluator | ||
| evaluator = Pi05Evaluator( | ||
| model=self.policy, | ||
| dataloader=eval_dataloader, | ||
| device=self.device | ||
| ) | ||
|
|
||
| # Perform evaluation | ||
| return evaluator.evaluate() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.