@@ -177,12 +177,12 @@ def _initialize_scale_zero_point(
     tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
     # Setting data for now - could possibly be handled later in the pipeline
     value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
-    # TODO: use model.weight.dtype
+    # TODO: use model.weight.dtype after checking
     value = value.to(torch.float32).to(device)
     # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
     init_global_scale = Parameter(value, requires_grad=False)
     register_offload_parameter(
-        module, f"f{base_name}_global_scale", init_global_scale
+        module, f"{base_name}_global_scale", init_global_scale
     )

     if scale_dtype not in [
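
For intuition, here is a minimal standalone sketch of the global-scale computation in the hunk above, assuming the usual NVFP4 format maxima (448.0 for FP8 E4M3, 6.0 for FP4 E2M1); the weight shape and variable names are illustrative, not from this PR:

import torch

FP8_E4M3_MAX = 448.0  # largest representable FP8 E4M3 value
FP4_E2M1_MAX = 6.0    # largest representable FP4 E2M1 value

weight = torch.randn(128, 128, dtype=torch.bfloat16)

# Same reduction as above: tensor-wide absolute max, promoted to fp32.
tensor_amax = torch.abs(weight.data).max().to(torch.float32)
global_scale = FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax

# Worst case: the block containing the global max. Its local scale
# (block_amax / FP4_E2M1_MAX), once multiplied by global_scale, lands
# exactly at FP8_E4M3_MAX, so per-block scales fit in FP8 without clipping.
block_amax = tensor_amax
local_scale = global_scale * block_amax / FP4_E2M1_MAX
assert torch.isclose(local_scale, torch.tensor(FP8_E4M3_MAX))

This pinning of the worst-case local scale to the top of the FP8 range is also why the inline comment questions whether the global scale must remain torch.float32.
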
@@ -201,7 +201,14 @@ def _initialize_scale_zero_point(
     register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
-        zp_dtype = quantization_args.pytorch_dtype()
+        if (
+            quantization_args.num_bits == 4
+            and quantization_args.type == QuantizationType.FLOAT
+        ):
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            zp_dtype = quantization_args.pytorch_dtype()
+
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
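
And a small sketch of what the new zero-point dtype branch does, with the QuantizationArgs fields reduced to plain arguments (pick_zp_dtype is a hypothetical helper for illustration, not part of the PR; it assumes FP8_E4M3_DATA.dtype resolves to torch.float8_e4m3fn):

import torch

def pick_zp_dtype(num_bits: int, qtype: str, fallback: torch.dtype) -> torch.dtype:
    # FP4 has no native torch dtype, so a 4-bit float zero point is stored
    # in FP8 E4M3; every other scheme keeps the dtype reported by
    # quantization_args.pytorch_dtype() (the fallback here).
    if num_bits == 4 and qtype == "float":
        return torch.float8_e4m3fn
    return fallback

assert pick_zp_dtype(4, "float", torch.int8) == torch.float8_e4m3fn
assert pick_zp_dtype(8, "int", torch.int8) == torch.int8
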