
Commit c154866

Init commit
1 parent 4b8fdfc commit c154866

File tree: 136 files changed, +43645 -0 lines changed


.gitignore

+184
@@ -0,0 +1,184 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Some encoder paths
facebook/
openai/
google/

# Log
logs/

# Output
outs/

# Checkpoints
checkpoints/

# VSC
.vscode/

# Wandb
wandb/

# Distributed learning
hostfile.txt
.deepspeed_env

README.md

+202
@@ -0,0 +1,202 @@

# RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation

### 📝[Paper]() | 🌍[Project Page](https://rdt-robotics.github.io/rdt-robotics/) | 🤗[Model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) | 🛢️[Data](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)

![](./assets/head.png)

RDT-1B is a **1B**-parameter (*largest* to date) imitation learning **Diffusion Transformer** pre-trained on **1M+** (*largest* to date) multi-robot episodes. Given a language instruction and RGB images of up to three views, RDT predicts the next $64$ robot actions. RDT is inherently compatible with **almost all kinds of modern mobile manipulators**: single-arm or dual-arm, joint or EEF, position or velocity, and even with wheeled locomotion.

We have fine-tuned RDT on **6K+** (one of the *largest*) self-collected bimanual episodes and deployed it on the ALOHA **dual-arm** robot. It achieves state-of-the-art performance in terms of dexterity, zero-shot generalizability, and few-shot learning. You can find demo videos on our [project page](https://rdt-robotics.github.io/rdt-robotics/).

This repo is an official PyTorch implementation of RDT, containing:

- 🛠️Model [implementation](models/rdt_runner.py) of RDT
- 🤗1M-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) of the pre-trained RDT-1B
- 📈Training and sampling [scripts](train/train.py) (with DeepSpeed)
- 🤖An [example](scripts/agilex_inference.py) of real-robot deployment

The following guides cover [installation](#installation), [fine-tuning](#fine-tuning-on-your-own-dataset), and [deployment](#deployment-on-real-robots). Please refer to [pre-training](docs/pretrain.md) for a detailed list of pre-training datasets and a pre-training guide.

## Installation

1. Clone this repo and install prerequisites:

```bash
# Clone this repo
git clone git@github.com:thu-ml/RoboticsDiffusionTransformer.git
cd RoboticsDiffusionTransformer

# Create a Conda environment
conda create -n rdt python=3.10.0
conda activate rdt

# Install PyTorch
# Look up https://pytorch.org/get-started/previous-versions/ with your CUDA version for the correct command
pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121

# Install packaging
pip install packaging==24.0

# Install flash-attn
pip install flash-attn --no-build-isolation

# Install other prerequisites
pip install -r requirements.txt
```

2. Download off-the-shelf multi-modal encoders:

You can download the encoders from the following links:

- `t5-v1_1-xxl`: [link](https://huggingface.co/google/t5-v1_1-xxl/tree/main)🤗
- `siglip`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)🤗

And link the encoders to the repo directory:

```bash
# Under the root directory of this repo
mkdir -p google

# Link the downloaded encoders to this repo
ln -s /path/to/t5-v1_1-xxl google/t5-v1_1-xxl
ln -s /path/to/siglip-so400m-patch14-384 google/siglip-so400m-patch14-384
```

## Fine-Tuning on Your Own Dataset

If your fine-tuning dataset belongs to [Open X-Embodiment](https://robotics-transformer-x.github.io/) or to the collection of our pre-training datasets (see [this doc](docs/pretrain.md)), you can also fine-tune RDT through the pre-training pipeline; you just need to remove the redundant datasets from the parameters. See [this guide](docs/pretrain.md) on pre-training.

1. Prepare your dataset:

You need to download your dataset to disk and give it a name, say `my_cool_dataset`.

Then, link your dataset into the repo directory:

```bash
# Under the root directory of this repo
cd data
mkdir -p datasets

# Link the downloaded dataset to this repo
ln -s /path/to/my_cool_dataset datasets/my_cool_dataset
```

2. Implement the dataset loader:

You need to:

1. Register the configuration of `my_cool_dataset`:

Append the control frequency of `my_cool_dataset` to [this file](configs/dataset_control_freq.json). Write the name of `my_cool_dataset` into [this file](configs/finetune_datasets.json) and [this file](configs/finetune_sample_weights.json); the value of the sampling weight doesn't matter since you only have one dataset. In these two files, we leave a placeholder of `agilex`; you can simply replace it with `my_cool_dataset`, as sketched below.
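
The exact schema of these JSON configs is defined by the files themselves. As a rough sketch, assuming `dataset_control_freq.json` and `finetune_sample_weights.json` are name-to-value objects and `finetune_datasets.json` is a list of names (check the files before running), the registration could be scripted like this:

```python
import json

DATASET_NAME = "my_cool_dataset"  # your dataset's name
CONTROL_FREQ = 25                 # your robot's control frequency in Hz (example value)

# Keep existing entries and append the control frequency of the new dataset.
with open("configs/dataset_control_freq.json") as f:
    control_freq = json.load(f)
control_freq[DATASET_NAME] = CONTROL_FREQ
with open("configs/dataset_control_freq.json", "w") as f:
    json.dump(control_freq, f, indent=4)

# Replace the `agilex` placeholder with the new dataset's name.
with open("configs/finetune_datasets.json", "w") as f:
    json.dump([DATASET_NAME], f, indent=4)

# The weight value is irrelevant when there is only one dataset.
with open("configs/finetune_sample_weights.json", "w") as f:
    json.dump({DATASET_NAME: 1.0}, f, indent=4)
```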

2. Re-implement the `HDF5VLADataset` class:

You can find this class in [this file](data/hdf5_vla_dataset.py). In this file, we provide an example of loading the fine-tuning dataset used in our paper (see [this link](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)).

To adapt it to your dataset, you need to: (a) modify `HDF5_DIR` (the directory of `my_cool_dataset`) and `DATASET_NAME` (which should be `"my_cool_dataset"`) in L21 and L22; (b) implement the two functions `parse_hdf5_file()` and `parse_hdf5_file_state_only()`. Refer to the original file for detailed comments and examples.

Note 1: Despite its name, you don't necessarily need to use HDF5 to store your data. Just make sure that the class is correctly implemented.

Note 2: During implementation, you may need to fill your robot's action into the unified action vector (L180-194). Please refer to [this file](configs/state_vec.py) for an explanation of each element in the unified vector; a sketch of this filling step follows the notes below.

**IMPORTANT 1:** If your robot is single-arm, please fill its action into the *right-arm* portion of the unified action vector, aligning with our pre-training datasets.
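
The following is a minimal sketch of this filling step for a single-arm robot. The vector dimension and the key names are placeholders; the authoritative name-to-index mapping is the one in [configs/state_vec.py](configs/state_vec.py).

```python
import numpy as np

# Placeholders: the real dimension and name-to-index mapping live in configs/state_vec.py.
UNI_VEC_DIM = 128
IDX = {f"right_arm_joint_{i}_pos": i for i in range(6)}
IDX["right_gripper_open"] = 6

def fill_in_state(joint_pos: np.ndarray, gripper_width: float) -> np.ndarray:
    """Scatter a single-arm robot's state into the unified vector.

    Per the notes above: single-arm values go into the *right-arm* portion,
    physical quantities stay in SI units (no normalization), and the gripper
    width is min-max normalized to [0, 1].
    """
    uni_vec = np.zeros(UNI_VEC_DIM, dtype=np.float32)
    for i, q in enumerate(joint_pos):  # joint positions in radians
        uni_vec[IDX[f"right_arm_joint_{i}_pos"]] = q
    uni_vec[IDX["right_gripper_open"]] = gripper_width  # already in [0, 1]
    return uni_vec

# Example: a 6-DoF arm at its zero pose with a half-open gripper
print(fill_in_state(np.zeros(6), 0.5))
```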

**IMPORTANT 2:** We use the [6D representation](https://arxiv.org/pdf/1812.07035) for EEF rotation. If your action space contains EEF rotation (angles or quaternion), please refer to [this file](docs/test_6drot.py) for the conversion. Note that this mapping is not reversible: different Euler angles may be equivalent and correspond to the same 6D representation.
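
For intuition, here is a minimal conversion sketch that assumes the first-two-columns convention of Zhou et al.; the exact axis order and convention used by RDT should be taken from [docs/test_6drot.py](docs/test_6drot.py).

```python
import numpy as np
from scipy.spatial.transform import Rotation

def euler_to_6d(roll: float, pitch: float, yaw: float) -> np.ndarray:
    """Euler angles (radians) -> 6D rotation representation.

    Following Zhou et al. (arXiv:1812.07035), the 6D representation is taken
    here as the first two columns of the 3x3 rotation matrix.
    """
    rot_mat = Rotation.from_euler("xyz", [roll, pitch, yaw]).as_matrix()
    return rot_mat[:, :2].T.reshape(-1)  # column 0 followed by column 1

# Equivalent Euler triplets map to the same 6D vector, which is why the
# mapping cannot be inverted back to a unique Euler representation.
print(euler_to_6d(0.0, 0.0, np.pi / 2))
```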

**IMPORTANT 3:** During pre-training, no physical quantities (except the gripper width) are normalized. We believe this preserves the physical meaning of each quantity, thereby promoting generalization across robots. Therefore, we encourage you not to normalize any physical quantities but to choose appropriate units for them. Generally, we use the International System of Units, which ensures that most values fall within [-1, 1]. As an exception, we min-max normalize the gripper width to [0, 1].

**IMPORTANT 4:** If you are using an RTX 4090 (or lower), the GPU memory may be too low to load the `t5-v1_1-xxl` encoder. Instead, we recommend precomputing the language embeddings (see [this file](scripts/encode_lang_batch.py) for an example script) and loading them during training. In this case, you need to specify the path to the embeddings in the `HDF5VLADataset` (see L148) rather than the natural language itself.
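
As a rough sketch of this precomputation (the reference implementation is [scripts/encode_lang_batch.py](scripts/encode_lang_batch.py), and the tensor layout it saves may differ), encoding an instruction with Hugging Face `transformers` could look like this:

```python
import os

import torch
from transformers import AutoTokenizer, T5EncoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")
encoder = T5EncoderModel.from_pretrained(
    "google/t5-v1_1-xxl", torch_dtype=torch.bfloat16
).to(device).eval()

instruction = "Pick up the red block and place it in the box."
tokens = tokenizer(instruction, return_tensors="pt").to(device)
with torch.no_grad():
    # (1, seq_len, hidden_dim) hidden states of the T5 encoder
    embeds = encoder(**tokens).last_hidden_state.cpu()

os.makedirs("outs/lang_embeddings", exist_ok=True)
torch.save({"instruction": instruction, "embeddings": embeds},
           "outs/lang_embeddings/your_instr.pt")
```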

3. Compute the dataset statistics for `my_cool_dataset`:

```bash
# Under the root directory of this repo
# Use -h to see the full usage
python -m data.compute_dataset_stat --dataset_type="finetune" --hdf5_dataset
```

3. Start fine-tuning:

Configurations relevant to model architecture and data processing are in [this file](configs/base.yaml). Normally, you do not need to modify these configurations; otherwise, loading the pre-trained checkpoint will fail. Configurations relevant to training are passed through *command line arguments*. Use `python main.py -h` to see their descriptions. We provide an example fine-tuning script in [this file](finetune.sh) (`finetune.sh`). You may need to modify some of the parameters in this file, such as `OUTPUT_DIR`, `CUTLASS_PATH`, and `WANDB_PROJECT`.

Use this to start fine-tuning:

```bash
source finetune.sh
```

with `finetune.sh` detailed as below:

```bash
deepspeed --hostfile=hostfile.txt main.py \
    --deepspeed="./configs/zero2.json" \ # If you want to use DeepSpeed, which is strongly recommended
    --pretrained_model_name_or_path=<MODEL ID | DIRECTORY OF MODEL WEIGHTS | PATH TO MODEL CHECKPOINT> \
    --pretrained_text_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY> \ # e.g., google/t5-v1_1-xxl
    --pretrained_vision_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY> \ # e.g., google/siglip-so400m-patch14-384
    --output_dir=<DIRECTORY TO SAVE CHECKPOINTS> \ # e.g., checkpoints/rdt-1b-agilex
    --train_batch_size=32 \
    --sample_batch_size=64 \ # batch size for diffusion sampling in validation
    --max_train_steps=200000 \
    --checkpointing_period=1000 \
    --sample_period=500 \ # sample period for validation
    --checkpoints_total_limit=40 \
    --lr_scheduler="constant" \
    --learning_rate=1e-4 \
    --mixed_precision="bf16" \ # If you want to use mixed precision, bf16 is recommended
    --dataloader_num_workers=8 \
    --image_aug \ # If you want to use image augmentation
    --dataset_type="finetune" \
    --state_noise_snr=40 \ # If you want to add noise to the state
    --load_from_hdf5 \ # If you use HDF5 to store your data
    --report_to=wandb
```

**IMPORTANT**: If you have already chosen to precompute the language embeddings, please specify `--precomp_lang_embed` in `finetune.sh`.

Note 1: `pretrained_model_name_or_path` can be one of:

- a string, the *model id* of a pre-trained model hosted in a model repo on Hugging Face. Please fill in `"robotics-diffusion-transformer/rdt-1b"`, which is the officially released [RDT-1B model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗 on Hugging Face. (recommended)
- a string, the path to a *directory* containing model weights manually downloaded from Hugging Face, e.g., `"/path/to/rdt-1b"`. You should first manually download the `rdt-1b` directory from this [link](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗.
- a string, the path to a *directory* containing model weights saved with the [`~RDTRunner.save_pretrained`] method. This can be either:
  - `"checkpoints/rdt-1b-pretrain/checkpoint-<STEP NUMBER>"`: the path to the checkpoint saved at the `<STEP NUMBER>` iteration during pre-training. Refer to [this file](docs/pretrain.md) for a tutorial on how to start your own pre-training.
  - `"checkpoints/rdt-1b-pretrain"`: if pre-training completes normally without any program failure, you can specify this path to load the last checkpoint.
- a string, the path to a model checkpoint (`*.pt`) saved by DeepSpeed, e.g., `"checkpoints/rdt-1b-pretrain/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt"` (verified)
- `None` if you want to randomly initialize the model using the configuration at `config_path`.

Note 2: You can monitor the training process by observing `loss` (through a long-window moving average) and `overall_avg_sample_mse` in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We empirically found that the lower the `overall_avg_sample_mse`, the better the model performs. Usually, fine-tuning is over when this value converges.

Note 3: If the training is oscillating, you can increase the batch size by adding more GPUs or setting a larger `--gradient_accumulation_steps`.

## Deployment on Real Robots

We have encapsulated model inference into a class named `RoboticDiffusionTransformerModel` (see L34 in [this file](scripts/agilex_model.py)). You can call the `step()` method of this class for inference. However, you may need to re-implement some parts of it according to your specific model. You should at least modify `_format_joint_to_state()` (L154) and `_unformat_action_to_joint()` (L186) to convert between your robot's raw actions and the unified action vectors that RDT accepts. You may also need to specify the control frequency of your robot (L45).

**IMPORTANT**: When you feed the images into `step()`, remember that the order MUST be `[ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_{t}, right_wrist_{t}, left_wrist_{t}]`, as in the sketch below.
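
As a sketch of only this ordering requirement (the real `step()` signature is defined in [scripts/agilex_model.py](scripts/agilex_model.py) and may differ), the six-image list could be assembled like this:

```python
import numpy as np

def grab(view: str) -> np.ndarray:
    """Dummy stand-in for your camera driver; returns an (H, W, 3) RGB frame."""
    return np.zeros((480, 640, 3), dtype=np.uint8)

# Frames from the previous timestep (t-1) and the current timestep (t).
prev = {"ext": grab("ext"), "right_wrist": grab("right"), "left_wrist": grab("left")}
curr = {"ext": grab("ext"), "right_wrist": grab("right"), "left_wrist": grab("left")}

# Required order: exterior, right wrist, left wrist at t-1, then the same three at t.
images = [
    prev["ext"], prev["right_wrist"], prev["left_wrist"],
    curr["ext"], curr["right_wrist"], curr["left_wrist"],
]

# Hypothetical call pattern; check RoboticDiffusionTransformerModel.step() in
# scripts/agilex_model.py for the actual argument names before using it:
# actions = policy.step(proprio=state_vec, images=images, text_embeds=lang_embeds)
```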

We provide example hardware code in [this file](scripts/agilex_inference.py) for deployment on Mobile ALOHA, and the corresponding running script in [this file](inference.sh) (`inference.sh`), detailed as below:

```bash
python -m scripts.agilex_inference \
    --use_actions_interpolation \
    --pretrained_model_name_or_path=<MODEL ID | DIRECTORY OF MODEL WEIGHTS | PATH TO MODEL CHECKPOINT> \ # same as the argument for fine-tuning, e.g., checkpoints/your_finetuned_ckpt.pt or checkpoints/your_finetuned_ckpt
    --lang_embeddings_path=<PATH TO YOUR INSTRUCTION EMBEDDINGS> \ # e.g., outs/lang_embeddings/your_instr.pt
    --ctrl_freq=25  # your control frequency
```

**IMPORTANT**: If your on-board GPU memory is not enough to encode the language, please refer to [this file](scripts/encode_lang.py) for precomputation and specify the language embedding path in `inference.sh`.

Note: If you want to deploy on the Mobile ALOHA robot, don't forget to install the hardware prerequisites (see [this repo](https://github.com/MarkFzp/mobile-aloha)).

## Citation

If you find our work useful, please cite us:

```bibtex

```

Thank you!

## License

All the code, model weights, and data are licensed under the [MIT license](./LICENSE).

assets/head.png

726 KB