float8 training axiswise scaling support with per-gemm-argument configuration #940

Merged: 51 commits merged into main on Oct 7, 2024

Conversation

@vkuzo vkuzo commented Sep 24, 2024

Summary:

This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet. Specifically, the additional combination we now support and test is a recipe from @lw, where we do the following:

```
output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp
```
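For orientation, here is a minimal high-precision sketch (plain PyTorch, not the Float8Linear implementation) of the three gemms of a linear layer that the per-gemm-argument configuration below targets; the toy shapes and variable names are illustrative only.

```python
import torch

# toy shapes: batch M, in_features K, out_features N
M, K, N = 16, 32, 8
x = torch.randn(M, K)    # input
w = torch.randn(N, K)    # nn.Linear stores weight as (N, K)
go = torch.randn(M, N)   # grad_output flowing back from the next layer

# gemm 1 (forward):  output = input @ weight_t
out = x @ w.t()          # (M, N)

# gemm 2 (backward): grad_input = grad_output @ weight
gi = go @ w              # (M, K)

# gemm 3 (backward): grad_weight = input_t @ grad_output, transposed here to
# match the (N, K) weight layout; this is the gemm the recipe keeps in high precision
gw = go.t() @ x          # (N, K)
```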

Key characteristics of this recipe:

  1. increased accuracy for `grad_weight`, which is important for real workloads
  2. `output` and `weight` now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels (a rough sketch of the granularity difference follows this list)
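To make the granularity difference concrete, below is an emulation-only sketch (not torchao code) of how a tensorwise scale versus an axiswise (per-row) scale would be computed before casting to float8; the helper names are hypothetical.

```python
import torch

F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def scale_tensorwise(x: torch.Tensor) -> torch.Tensor:
    # one scale for the whole tensor
    return F8_MAX / x.abs().max().clamp(min=1e-12)

def scale_axiswise(x: torch.Tensor, dim: int) -> torch.Tensor:
    # one scale per slice, reducing the abs-max along `dim`
    amax = x.abs().amax(dim=dim, keepdim=True)
    return F8_MAX / amax.clamp(min=1e-12)

def to_float8_emulated(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale, saturate, and cast down; the scale is kept alongside to undo later
    return (x * scale).clamp(-F8_MAX, F8_MAX).to(torch.float8_e4m3fn)

x = torch.randn(4096, 4096)
x_f8_tw = to_float8_emulated(x, scale_tensorwise(x))       # tensorwise: 1 scale
x_f8_aw = to_float8_emulated(x, scale_axiswise(x, dim=1))  # axiswise: 1 scale per row
```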

Here is how a user can configure this:

```python
#
# short form
#

config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
# (imports of CastConfig, ScalingGranularity, ScalingType, Float8GemmConfig,
# Float8Config, and Float8LinearRecipeName from torchao.float8.config are assumed)
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8Config(
    cast_config_input=cc_i,
    cast_config_weight=cc_w,
    cast_config_grad_output=cc_go,
    cast_config_input_for_grad_weight=cc_i_gw,
    cast_config_weight_for_grad_output=cc_w_go,
    cast_config_grad_output_for_grad_weight=cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
```
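As a hedged usage sketch (assuming torchao.float8's convert_to_float8_training entry point accepts the config built above via its config argument, and that an fp8-capable GPU such as an H100 is available; the toy model is just a placeholder):

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training  # assumed import path

# placeholder model; any stack of nn.Linear layers works
m = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).bfloat16().cuda()

# swap nn.Linear -> Float8Linear using the recipe configured above
convert_to_float8_training(m, config=config)
m = torch.compile(m)  # compile to fuse the scaling/casting overhead

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
y = m(x)
y.sum().backward()
```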

performance

Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes.

gemm performance of torch._scaled_mm

baseline: tensorwise scaling

> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134                

experiment: axiswise-scaling

> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533

performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For fp8_dynamic_axiswise, the gap from tensorwise appears to be mostly due to the axiswise gemm being slower than the tensorwise gemm.

> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ... fp8_del_s fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...  0.000043      0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...  0.000047      0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...  0.000060      0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...  0.000147      0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...  0.000602      0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...  0.003665      0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...  0.024664      0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...  0.182645      0.240962  0.270973    1.696376    1.766307        1.338827   1.190552

performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:

  • baseline (bf16 + compile): 6,294 wps
  • f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
  • f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
  • LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

So it looks like we have performance work to do for LW_AXISWISE_WITH_GW_HP in future PRs.

accuracy

I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations. I will leave longer accuracy verifications for future work.

[Screenshot 2024-10-04 at 10 05 24 PM: loss curves for bf16 / f8_tensorwise / f8_axiswise / f8_lw]

vkuzo added a commit that referenced this pull request Sep 24, 2024
vkuzo added a commit that referenced this pull request Sep 25, 2024
vkuzo added a commit that referenced this pull request Sep 25, 2024
vkuzo added a commit that referenced this pull request Sep 27, 2024
vkuzo added a commit that referenced this pull request Sep 27, 2024
vkuzo added a commit that referenced this pull request Sep 27, 2024
vkuzo added a commit that referenced this pull request Sep 27, 2024
@vkuzo vkuzo changed the title [wip] make scaling configurable by gemm-argument make float8 scaling configurable by gemm-argument Sep 27, 2024
vkuzo added 3 commits October 4, 2024 09:50
vkuzo added a commit that referenced this pull request Oct 4, 2024
vkuzo added a commit that referenced this pull request Oct 4, 2024
vkuzo added a commit that referenced this pull request Oct 4, 2024
vkuzo added a commit that referenced this pull request Oct 4, 2024
vkuzo added a commit that referenced this pull request Oct 4, 2024
vkuzo added a commit that referenced this pull request Oct 4, 2024
vkuzo added a commit that referenced this pull request Oct 4, 2024
@vkuzo vkuzo changed the title make float8 scaling configurable by gemm-argument official axiswise scaling support with per-gemm-argument configuration Oct 5, 2024
@vkuzo vkuzo changed the title official axiswise scaling support with per-gemm-argument configuration float8 training axiswise scaling support with per-gemm-argument configuration Oct 5, 2024
vkuzo added a commit that referenced this pull request Oct 5, 2024
vkuzo added 2 commits October 7, 2024 10:21
vkuzo added a commit that referenced this pull request Oct 7, 2024
@vkuzo vkuzo changed the base branch from gh/vkuzo/11/head to main October 7, 2024 20:59
vkuzo added a commit that referenced this pull request Oct 7, 2024
@vkuzo vkuzo merged commit dec0313 into main Oct 7, 2024
43 checks passed
jainapurva pushed a commit that referenced this pull request Oct 9, 2024
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024