
Commit eec7bd3

update
1 parent 1727508 commit eec7bd3

2 files changed (+20 −8 lines)


src/compressed_tensors/quantization/lifecycle/initialize.py (+10 −3)
@@ -177,12 +177,12 @@ def _initialize_scale_zero_point(
         tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
         # Setting data for now - could possibly be handled later in the pipeline
         value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
-        # TODO: use model.weight.dtype
+        # TODO: use model.weight.dtype after checking
         value = value.to(torch.float32).to(device)
         # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
         init_global_scale = Parameter(value, requires_grad=False)
         register_offload_parameter(
-            module, f"f{base_name}_global_scale", init_global_scale
+            module, f"{base_name}_global_scale", init_global_scale
         )

         if scale_dtype not in [
@@ -201,7 +201,14 @@ def _initialize_scale_zero_point(
     register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
-        zp_dtype = quantization_args.pytorch_dtype()
+        if (
+            quantization_args.num_bits == 4
+            and quantization_args.type == QuantizationType.FLOAT
+        ):
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            zp_dtype = quantization_args.pytorch_dtype()
+
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
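
For reference, the math this file change introduces can be exercised on its own. The sketch below is a minimal, standalone reconstruction in plain torch, assuming 448.0 and 6.0 as the values behind FP8_E4M3_DATA.max and FP4_E2M1_DATA.max, torch.float8_e4m3fn as what FP8_E4M3_DATA.dtype resolves to, and torch.int8 as an illustrative stand-in for quantization_args.pytorch_dtype(); none of these constants are defined in this diff.

import torch

# Assumed format maxima (stand-ins for FP8_E4M3_DATA.max and FP4_E2M1_DATA.max).
FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0

weight = torch.randn(128, 256, dtype=torch.bfloat16)

# Per-tensor global scale, as in the first hunk: fold the FP8 and FP4 ranges
# over the weight's absolute maximum, kept in float32.
tensor_amax = torch.abs(weight).max().to(torch.float32)
global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax).to(torch.float32)

# Zero-point dtype selection, as in the second hunk: 4-bit float quantization
# stores zero points as FP8 E4M3; other schemes keep their usual dtype
# (torch.int8 here is only an illustrative stand-in).
num_bits, is_float = 4, True
zp_dtype = torch.float8_e4m3fn if (num_bits == 4 and is_float) else torch.int8

print(global_scale.item(), zp_dtype)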

src/compressed_tensors/quantization/utils/helpers.py (+10 −5)
@@ -74,7 +74,9 @@ def calculate_qparams(

     bit_min, bit_max = calculate_range(quantization_args, device)
     bit_range = bit_max - bit_min
-    zp_dtype = quantization_args.pytorch_dtype()
+    # TODO: update
+    # zp_dtype = quantization_args.pytorch_dtype()
+    zp_dtype = FP8_E4M3_DATA.dtype

     if quantization_args.symmetric:
         # TODO: update for NVFP4 when applying observers
@@ -85,15 +87,18 @@
             and quantization_args.type == QuantizationType.FLOAT
         ):
             assert global_scale is not None
-            scale = max_val_pos / FP4_E2M1_DATA.max # Not needed
-            scale = scale / global_scale
-            scale = scale.to(FP8_E4M3_DATA.dtype) # .to(torch.float32)
+            breakpoint()
+            scales = max_val_pos / FP4_E2M1_DATA.max # Not needed
+            scales = scales / global_scale
+            scales = scales.to(FP8_E4M3_DATA.dtype) # .to(torch.float32)

         else:
             # Divide over bit range over max value?
             scales = max_val_pos / (float(bit_range) / 2)

-        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # TODO: clamp not implemented for FP8 '
+        breakpoint()
+        # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
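
The symmetric NVFP4 branch of calculate_qparams reduces to a few tensor ops once the global scale exists. Below is a minimal sketch in plain torch, assuming 6.0 as the value of FP4_E2M1_DATA.max and torch.float8_e4m3fn as the dtype behind FP8_E4M3_DATA.dtype, with made-up example values for max_val_pos and global_scale.

import torch

FP4_E2M1_MAX = 6.0  # assumed value of FP4_E2M1_DATA.max

# Example per-group absolute maxima and a per-tensor global scale
# (e.g. 448 * 6 / tensor_amax from initialize.py); values are illustrative.
max_val_pos = torch.tensor([2.5, 0.75, 4.0], dtype=torch.float32)
global_scale = torch.tensor(672.0, dtype=torch.float32)

# Local FP4 scale per group, folded into the global scale, stored as FP8 E4M3.
scales = max_val_pos / FP4_E2M1_MAX
scales = scales / global_scale
scales = scales.to(torch.float8_e4m3fn)  # assumed FP8_E4M3_DATA.dtype

# Per the TODO above, torch.clamp is not implemented for FP8 tensors, so any
# eps clamp would need to happen before the cast to FP8.
print(scales.to(torch.float32))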
