This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit bc23bd2

bring back torch.autograd.Function for float8 matmul
Summary:

This is a redo of #316.

With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in the fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 6cb1588
Pull Request resolved: #336
1 parent 8b6a015 commit bc23bd2

File tree

1 file changed: +57, -1 lines changed

float8_experimental/float8_linear.py

Lines changed: 57 additions & 1 deletion
```diff
@@ -71,6 +71,62 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         scale.copy_(new_scale)
 
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+@torch._dynamo.allow_in_graph
+class manual_float8_matmul(torch.autograd.Function):
+    """
+    Like torch.matmul, but with the arguments in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_fp8,
+        weight_fp8_t,
+    ):
+        ctx.save_for_backward(input_fp8, weight_fp8_t)
+        # the reshapes are needed in order to make the shapes compatible with
+        # torch.mm
+        orig_shape = input_fp8.shape
+        input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
+        res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, grad_output_fp8):
+        input_fp8, weight_fp8_t = ctx.saved_tensors
+
+        # the reshapes are needed in order to make the shapes compatible with
+        # torch.mm
+        grad_output_fp8_orig_shape = grad_output_fp8.shape
+        grad_output_fp8_reshaped = grad_output_fp8.reshape(
+            -1, grad_output_fp8_orig_shape[-1]
+        )
+
+        # calculate grad_input
+        grad_input = torch.mm(
+            grad_output_fp8_reshaped,
+            weight_fp8_t.t(),
+        )
+        grad_input = grad_input.reshape(
+            *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
+        )
+
+        input_fp8_orig_shape = input_fp8.shape
+        input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])
+
+        # calculate grad_weight
+        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
+        # compared to calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
+        grad_weight = torch.mm(
+            grad_output_fp8_reshaped.t(),
+            input_fp8_reshaped,
+        )
+
+        return grad_input, grad_weight.t()
+
+
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
@@ -393,7 +449,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
         weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
 
-        output = torch.matmul(input_fp8, weight_fp8.t())
+        output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
 
         # Cast grad_output to float8_e5m2 during backward
         output = self.cast_output_to_float8_in_bw(output)
```
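The summary's claim that this kind of override now composes with `torch.compile` can be illustrated with a minimal, self-contained sketch. This is not part of the diff: `toy_manual_matmul`, `toy_fn`, and the toy shapes below are made up for illustration; only the `@torch._dynamo.allow_in_graph` + `torch.autograd.Function` pattern is taken from the change above.

```python
# Minimal sketch (hypothetical toy code, not from this repo): a custom
# torch.autograd.Function allowed into the dynamo graph and traced by
# torch.compile, with gradients flowing through the hand-written backward.
import torch


@torch._dynamo.allow_in_graph
class toy_manual_matmul(torch.autograd.Function):
    """Like torch.mm, but with the backward written out by hand."""

    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.mm(x, w_t)

    @staticmethod
    def backward(ctx, grad_out):
        x, w_t = ctx.saved_tensors
        grad_x = torch.mm(grad_out, w_t.t())  # dL/dx
        grad_w_t = torch.mm(x.t(), grad_out)  # dL/dw_t
        return grad_x, grad_w_t


def toy_fn(x, w_t):
    return toy_manual_matmul.apply(x, w_t).sum()


compiled_fn = torch.compile(toy_fn)
x = torch.randn(4, 8, requires_grad=True)
w_t = torch.randn(8, 16, requires_grad=True)
compiled_fn(x, w_t).backward()  # backward runs through the manual override
```

The real `manual_float8_matmul` in the diff follows the same structure, but its inputs are float8 tensors and it reshapes them so each matmul maps onto `torch.mm`, which keeps the choice of gemm kernel and the scaling behavior explicit in both the forward and backward passes, as the commit summary describes.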

0 commit comments
