README.md

This is an early version of a library for accelerating training with float8 in native PyTorch
according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
and composable with key systems such as autograd, ``torch.compile`` and distributed.
With ``torch.compile`` on, initial results show
throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.

:warning: <em>See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features.</em>

:warning: <em>Backwards compatibility is not guaranteed at this point. The codebase is in active development and
will change rapidly.</em>

```bash
pip install -e .
pip install -e ".[dev]"
```

# Single GPU User API

We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).

```python
for _ in range(N_ITER):
    ...
    optimizer.step()
```
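Expanding the loop fragment above into something runnable, here is a minimal single-GPU sketch of the dynamic scaling flow. The `swap_linear_with_float8_linear` helper, the `Float8DynamicLinear` module class, and their import paths are assumptions about the current layout of this repository and may change; delayed scaling follows the same shape but also synchronizes the amax/scale history once per iteration.

```python
import torch
import torch.nn as nn

# Assumed import paths; the API surface is experimental and may change.
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

N_ITER = 10

# a toy model; float8 matmuls assume a recent CUDA GPU (e.g. H100) and bf16 weights
m = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
).to(device="cuda", dtype=torch.bfloat16)

# swap every nn.Linear for a float8 linear using dynamic per-tensor scaling
swap_linear_with_float8_linear(m, Float8DynamicLinear)

# optional: torch.compile is supported and is where most of the speedup comes from
m = torch.compile(m)

optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# standard training loop; dynamic scaling needs no extra synchronization calls
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```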
# Multi GPU User API

We compose with the `DTensor`-based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html),
such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for
end-to-end examples of using `float8_experimental` in a distributed setting.

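For illustration only, a hedged sketch of the FSDP path is below. It reuses the assumed float8 conversion helpers from the single-GPU sketch and assumes the script is launched with `torchrun`; torchtitan remains the reference for maintained end-to-end configurations, including TP/SP via `DTensor`.

```python
# Launch with e.g.: torchrun --nproc_per_node=2 train_fsdp.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumed import paths, as in the single-GPU sketch above.
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

m = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
).to(device="cuda", dtype=torch.bfloat16)

# convert to float8 *before* wrapping, so FSDP shards the float8 linears
swap_linear_with_float8_linear(m, Float8DynamicLinear)
m = FSDP(m, use_orig_params=True)

optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

for _ in range(10):
    optimizer.zero_grad()
    m(x).sum().backward()
    optimizer.step()

dist.destroy_process_group()
```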
# Testing

```bash
# run single-GPU unit tests
pytest test/test_base.py

# run single-GPU compile tests
pytest test/test_compile.py

# run single-GPU numerics integration tests
pytest test/test_numerics_integration.py

# run a two-GPU integration test on FSDP
./test/test_fsdp.sh

# run integration tests on the DTensor TP/SP integration
```