This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 603efc2

vkuzo authored and facebook-github-bot committed
change x, w, dL_dY variable names to input, weight, grad_output (#323)
Summary:
Pull Request resolved: #323

The following naming scheme matches the rest of PyTorch better:

```python
# forward
output = input @ weight_t

# backward
grad_input = grad_output @ weight
grad_weight = input_t @ grad_output
```

This PR changes all the previous references to `x`, `w`, `dL_dY` to match the naming scheme above.

Reviewed By: drisspg

Differential Revision: D60072596

fbshipit-source-id: 74e89d154a698a0dae8c92f39e2267409b151642
1 parent 9d5f892 commit 603efc2

19 files changed: +349 −305 lines
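
To make the new convention concrete, here is a minimal sketch of the matmuls the renamed variables refer to, written against plain torch tensors. The shapes and the explicit `_t` transposes are illustrative assumptions; only the variable names and the three formulas come from the commit message above.

```python
import torch

# Illustrative shapes: batch=4, in_features=8, out_features=16.
input = torch.randn(4, 8)             # activations (previously "x")
weight = torch.randn(16, 8)           # linear weight (previously "w")
weight_t = weight.t()
input_t = input.t()

# forward
output = input @ weight_t             # (4, 16)

# backward, given the gradient of the loss w.r.t. the output (previously "dL_dY")
grad_output = torch.randn(4, 16)
grad_input = grad_output @ weight     # (4, 8), gradient w.r.t. input
grad_weight = input_t @ grad_output   # (8, 16), gradient w.r.t. weight
```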

.github/workflows/ufmt.yml

Lines changed: 3 additions & 0 deletions

@@ -23,4 +23,7 @@ jobs:
           pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1
       - name: Analyzing the code with ufmt
         run: |
+          ufmt format .
+          git diff
+          git restore .
           ufmt check .

README.md

Lines changed: 5 additions & 5 deletions

@@ -28,9 +28,9 @@ pip install -e ".[dev]"
 
 # Single GPU User API
 
-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).
 
-## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`
 
 This is the most accurate recipe as every tensor is scaled dynamically.
 
@@ -95,9 +95,9 @@ m = Model(...)
 # type
 swap_linear_with_float8_linear(
     m,
-    scaling_type_x=TensorScalingType.DELAYED,
-    scaling_type_w=TensorScalingType.DELAYED,
-    scaling_type_dL_dY=TensorScalingType.DELAYED,
+    scaling_type_input=TensorScalingType.DELAYED,
+    scaling_type_weight=TensorScalingType.DELAYED,
+    scaling_type_grad_output=TensorScalingType.DELAYED,
 )
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
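
For orientation, a minimal sketch of the dynamic-scaling swap the updated README section describes, using the post-rename keyword arguments. The import paths and the toy model are assumptions; `swap_linear_with_float8_linear`, `TensorScalingType`, and the `scaling_type_*` kwargs are taken from the diff.

```python
import torch
import torch.nn as nn

# NOTE: import paths below are assumptions about the package layout; only the
# function/enum names and the scaling_type_* kwargs appear in this diff.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
m = m.to("cuda", torch.bfloat16)

# Dynamic scaling for activations, weights, and gradients (post-rename kwargs).
swap_linear_with_float8_linear(
    m,
    scaling_type_input=TensorScalingType("dynamic"),
    scaling_type_weight=TensorScalingType("dynamic"),
    scaling_type_grad_output=TensorScalingType("dynamic"),
)

# Requires fp8-capable hardware for the non-emulated path.
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
```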

benchmarks/bench_linear_float8.py

Lines changed: 21 additions & 19 deletions

@@ -95,16 +95,16 @@ def main(
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
-    scaling_type_x: str = "dynamic",
-    scaling_type_w: str = "dynamic",
-    scaling_type_dL_dY: str = "dynamic",
+    scaling_type_input: str = "dynamic",
+    scaling_type_weight: str = "dynamic",
+    scaling_type_grad_output: str = "dynamic",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
-    scaling_type_x = TensorScalingType(scaling_type_x)
-    scaling_type_w = TensorScalingType(scaling_type_w)
-    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    scaling_type_input = TensorScalingType(scaling_type_input)
+    scaling_type_weight = TensorScalingType(scaling_type_weight)
+    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
 
     # LLaMa 2 70B single-node weight shapes
     # assumes fused attn.wqkv and ffn.w13
@@ -136,9 +136,9 @@ def main(
         linear_float8 = Float8Linear.from_float(
             copy.deepcopy(linear_ref),
             emulate=False,
-            scaling_type_x=scaling_type_x,
-            scaling_type_w=scaling_type_w,
-            scaling_type_dL_dY=scaling_type_dL_dY,
+            scaling_type_input=scaling_type_input,
+            scaling_type_weight=scaling_type_weight,
+            scaling_type_grad_output=scaling_type_grad_output,
         )
         scaling_repr = linear_float8.scaling_repr()
 
@@ -153,7 +153,9 @@ def main(
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
         def float8_forw_backward():
-            if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
+            if linear_requires_sync(
+                scaling_type_input, scaling_type_weight, scaling_type_grad_output
+            ):
                 sync_float8_amax_and_scale_history(linear_float8)
             linear_float8(input_tensor).sum().backward()
 
@@ -278,18 +280,18 @@ def invoke_main() -> None:
     parser.add_argument("-n", "--n_limit", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
-    parser.add_argument("--scaling_type_x", type=str, required=False)
-    parser.add_argument("--scaling_type_w", type=str, required=False)
-    parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
+    parser.add_argument("--scaling_type_input", type=str, required=False)
+    parser.add_argument("--scaling_type_weight", type=str, required=False)
+    parser.add_argument("--scaling_type_grad_output", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
     kwargs = {}
-    if args.scaling_type_x is not None:
-        kwargs["scaling_type_x"] = args.scaling_type_x
-    if args.scaling_type_w is not None:
-        kwargs["scaling_type_w"] = args.scaling_type_w
-    if args.scaling_type_dL_dY is not None:
-        kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
+    if args.scaling_type_input is not None:
+        kwargs["scaling_type_input"] = args.scaling_type_input
+    if args.scaling_type_weight is not None:
+        kwargs["scaling_type_weight"] = args.scaling_type_weight
+    if args.scaling_type_grad_output is not None:
+        kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
     main(
         output_path,
         not args.disable_compile,
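
The gating pattern in the benchmark above generalizes to any delayed-scaling training loop: sync the amax and scale history only when at least one tensor uses delayed scaling. A minimal sketch under that assumption follows; the import paths and the toy model are guesses, while the function names, kwargs, and `TensorScalingType.DELAYED` come from this diff.

```python
import torch
import torch.nn as nn

# NOTE: import paths are assumptions; the function names, kwargs, and
# TensorScalingType.DELAYED all appear in the diff above.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

scaling_type_input = TensorScalingType.DELAYED
scaling_type_weight = TensorScalingType.DELAYED
scaling_type_grad_output = TensorScalingType.DELAYED

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
model = model.to("cuda", torch.bfloat16)
swap_linear_with_float8_linear(
    model,
    scaling_type_input=scaling_type_input,
    scaling_type_weight=scaling_type_weight,
    scaling_type_grad_output=scaling_type_grad_output,
)

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
for _ in range(10):
    # Delayed scaling keeps amax/scale history that must be synced each step.
    if linear_requires_sync(
        scaling_type_input, scaling_type_weight, scaling_type_grad_output
    ):
        sync_float8_amax_and_scale_history(model)
    model(x).sum().backward()
```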

benchmarks/bench_multi_gpu.py

Lines changed: 3 additions & 3 deletions

@@ -68,9 +68,9 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
         swap_linear_with_float8_linear(
             m,
             emulate=False,
-            scaling_type_x=TensorScalingType.DELAYED,
-            scaling_type_w=TensorScalingType.DELAYED,
-            scaling_type_dL_dY=TensorScalingType.DELAYED,
+            scaling_type_input=TensorScalingType.DELAYED,
+            scaling_type_weight=TensorScalingType.DELAYED,
+            scaling_type_grad_output=TensorScalingType.DELAYED,
         )
     return m
 

benchmarks/profile_linear_float8.py

Lines changed: 16 additions & 11 deletions

@@ -204,20 +204,23 @@ def profile_function(
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
-    scaling_type_x: str = "dynamic",
-    scaling_type_w: str = "dynamic",
-    scaling_type_dL_dY: str = "dynamic",
+    scaling_type_input: str = "dynamic",
+    scaling_type_weight: str = "dynamic",
+    scaling_type_grad_output: str = "dynamic",
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
-    scaling_type_x = TensorScalingType(scaling_type_x)
-    scaling_type_w = TensorScalingType(scaling_type_w)
-    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    scaling_type_input = TensorScalingType(scaling_type_input)
+    scaling_type_weight = TensorScalingType(scaling_type_weight)
+    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
     scaling_repr = "_".join(
-        [s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
+        [
+            s.short_str()
+            for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
+        ]
     )
 
     print(f"Compile is set to | {compile}")
@@ -254,9 +257,9 @@ def main(
     m_ref = m_ref.to(device).to(ref_dtype)
 
     extra_kwargs = {
-        "scaling_type_x": scaling_type_x,
-        "scaling_type_w": scaling_type_w,
-        "scaling_type_dL_dY": scaling_type_dL_dY,
+        "scaling_type_input": scaling_type_input,
+        "scaling_type_weight": scaling_type_weight,
+        "scaling_type_grad_output": scaling_type_grad_output,
     }
 
     m_float8 = copy.deepcopy(m_ref)
@@ -278,7 +281,9 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
+        if linear_requires_sync(
+            scaling_type_input, scaling_type_weight, scaling_type_grad_output
+        ):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)