Commit 8c8bc81
Move quant ops to utils.py (#331)
Summary: Many of our "quant primitive" ops can be expressed in terms of more primitive ops, so they are effectively helper functions now; this change moves them to torchao.quantization.utils. We should be able to further deprecate some of these ops after we deprecate the tensor subclasses and refactor smoothquant etc. in the future. Also moves TORCH_VERSION_AFTER_{2_2/2_3/2_4} from torchao.quantization.utils to torchao.utils.

Test Plan:
python test/integration/test_integration.py
python test/quantization/test_quant_api.py
python test/quantization/test_quant_primitives.py

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent cd8f647 commit 8c8bc81

17 files changed: +431, -853 lines
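
For downstream code, the visible effect of this commit is the import paths. A minimal sketch of the updated imports, using only names that appear in the diffs below (module export lists beyond these names are not shown in this commit):

# Helper ops now live in torchao.quantization.utils (moved out of quant_primitives).
from torchao.quantization.utils import (
    get_group_qparams_symmetric,
    dequantize_per_channel,
    dynamically_quantize_per_channel,
    quantize_activation_per_token_absmax,
)

# Version guards moved from torchao.quantization.utils to torchao.utils.
from torchao.utils import TORCH_VERSION_AFTER_2_4

# Lower-level affine primitives stay in torchao.quantization.quant_primitives.
from torchao.quantization.quant_primitives import (
    safe_int_mm,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
    MappingType,
)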

Diff for: test/dtypes/test_nf4.py

+2 -2

@@ -241,7 +241,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
         a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
         inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
         out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
-
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     @parametrize("shape", [(16, 16), (32, 16)])
@@ -430,7 +430,7 @@ def test_to_cpu(self):
         for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor = getattr(nf4_tensor, attr)
            self.assertEqual(inner_tensor.device.type, "cpu")
-
+

 instantiate_parametrized_tests(TestNF4Linear)
 instantiate_parametrized_tests(TestFSDPOps)

Diff for: test/integration/test_integration.py

+8 -289

@@ -28,15 +28,18 @@
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
+    safe_int_mm,
+    choose_qparams_affine,
+    quantize_affine,
+    dequantize_affine,
+    MappingType,
+)
+from torchao.quantization.utils import (
     dequantize_per_channel,
     dequantize_per_tensor,
     dynamically_quantize_per_channel,
-    dynamically_quantize_per_tensor,
-    quant_int8_dynamic_linear,
     quant_int8_dynamic_per_token_linear,
     quantize_activation_per_token_absmax,
-    safe_int_mm,
-    dequantize_affine,
 )

 from torchao.quantization.smoothquant import (
@@ -369,167 +372,7 @@ def test_debug_x_absmax(self):
         y1 = m(x0)


-class PythonQuantPrimitivesUnitTest(unittest.TestCase):
-    def _test_dynamic_quant_per_tensor_numerics_impl(
-        self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme
-    ):
-        x = torch.randn(256, dtype=float_dtype, device=device)
-        y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor(
-            x, qmin, qmax, int_dtype, qscheme
-        )
-
-        # reference
-        # quantize_per_tensor_dynamic doesn't work for half, so we cast there and back
-        x_for_ref = x.half().float() if float_dtype == torch.float16 else x
-
-        # quantize_per_tensor_dynamic doesn't support qscheme, so we just do dynamic
-        # quant manually with observers + static quant
-        obs = MinMaxObserver(
-            dtype=qint_dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax
-        ).to(device)
-        obs(x_for_ref)
-        ref_scale, ref_zero_point = obs.calculate_qparams()
-        y_ref = torch.quantize_per_tensor(
-            x_for_ref, ref_scale, ref_zero_point, qint_dtype
-        )
-
-        # y_ref = torch.quantize_per_tensor_dynamic(x_for_ref, qint_dtype, False)
-        # print(y_ref)
-        if float_dtype == torch.float:
-            assert torch.equal(y_vals, y_ref.int_repr())
-        else:
-            # numerics are not exactly aligned yet, off-by-one probably due
-            # to rounding
-            assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1
-        torch.testing.assert_close(
-            y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype)
-        )
-        if y_zero_point is not None:
-            assert torch.equal(
-                y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device)
-            )
-        else:
-            self.assertTrue(y_ref.q_zero_point() == 0)
-
-        # dequantize and check again
-        x_dq = dequantize_per_tensor(y_vals, y_scale, y_zero_point, float_dtype)
-        y_ref_dq = y_ref.dequantize().to(float_dtype)
-        if float_dtype == torch.float:
-            torch.testing.assert_close(x_dq, y_ref_dq)
-        else:
-            sqnr = compute_error(x_dq, y_ref_dq)
-            self.assertTrue(sqnr.item() > 45.0)
-
-    def test_dynamic_quant_per_tensor_numerics_cpu(self):
-        # verifies that dynamic quant per tensor in plain pytorch matches
-        # numerics of production AO code
-        # TODO(future): test this on cpu-half, need to first make
-        # torch.aminmax support half on cpu
-        test_cases = (
-            (
-                0,
-                255,
-                torch.uint8,
-                torch.quint8,
-                torch.float32,
-                "cpu",
-                torch.per_tensor_affine,
-            ),
-            (
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cpu",
-                torch.per_tensor_affine,
-            ),
-            (
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cpu",
-                torch.per_tensor_symmetric,
-            ),
-            (
-                -127,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cpu",
-                torch.per_tensor_symmetric,
-            ),
-        )
-        for row in test_cases:
-            self._test_dynamic_quant_per_tensor_numerics_impl(*row)
-
-    @unittest.skip("test case incorrect on A10G")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_dynamic_quant_per_tensor_numerics_cuda(self):
-        # verifies that dynamic quant per tensor in plain pytorch matches
-        # numerics of production AO code
-        test_cases = (
-            (
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cuda",
-                torch.per_tensor_affine,
-            ),
-            (
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float16,
-                "cuda",
-                torch.per_tensor_affine,
-            ),
-            (
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cuda",
-                torch.per_tensor_symmetric,
-            ),
-            (
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float16,
-                "cuda",
-                torch.per_tensor_symmetric,
-            ),
-            (
-                -127,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cuda",
-                torch.per_tensor_symmetric,
-            ),
-            (
-                -127,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float16,
-                "cuda",
-                torch.per_tensor_symmetric,
-            ),
-        )
-        for row in test_cases:
-            self._test_dynamic_quant_per_tensor_numerics_impl(*row)
-
+class PythonQuantUtilOpUnitTest(unittest.TestCase):
     def _test_dynamic_quant_per_channel_numerics_impl(
         self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
     ):
@@ -705,130 +548,6 @@ def wrap_torch_int_mm(x, w):
         torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0)
         torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0)

-    def _test_qlinear_per_channel_numerics(
-        self, x_shape, lin_shape, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
-    ):
-        qconfig = torch.ao.quantization.per_channel_dynamic_qconfig
-
-        x = torch.randn(*x_shape, device=device, dtype=float_dtype)
-
-        # TODO: test bias true and false
-        # Note: reference path only works on float because lack of aten quant primitives
-        # support of half, so we cast back and forth to emulate
-        lin_ref = (
-            nn.Sequential(nn.Linear(*lin_shape))
-            .eval()
-            .to(float_dtype)
-            .float()
-            .to(device)
-        )
-        y_ref = lin_ref(x.float())
-        weight = lin_ref[0].weight
-        bias = lin_ref[0].bias
-
-        qconfig_mapping = QConfigMapping().set_global(qconfig)
-        lin_ref_p = prepare_fx(lin_ref, qconfig_mapping, (torch.randn(1, 1),))
-        lin_ref_q = convert_to_reference_fx(lin_ref_p)
-        y_q_ref = lin_ref_q(x.float())
-
-        # scale, zp of weight (get from reference model)
-        w_obs = qconfig.weight()
-        w_obs(weight)
-        lin_ref_w_scale, lin_ref_w_zp = w_obs.calculate_qparams()
-        lin_ref_w_scale = lin_ref_w_scale.to(device).to(float_dtype)
-        # print('lin_ref_w', 'scale', lin_ref_w_scale, 'zp', lin_ref_w_zp)
-
-        w_vals, _s, _z = dynamically_quantize_per_channel(
-            getattr(lin_ref_q, "0").weight.to(float_dtype), -128, 127, torch.int8
-        )
-        w_vals = w_vals.t().contiguous()
-        w_vals_sums = w_vals.sum(dim=0)
-
-        # do our version of the quantized linear operator
-        y = quant_int8_dynamic_linear(
-            x,
-            qmin,
-            qmax,
-            int_dtype,
-            w_vals,
-            lin_ref_w_scale,
-            w_vals_sums,
-            bias,
-            float_dtype,
-        )
-
-        # print('y', y)
-        # print('y_q_ref', y_q_ref)
-        # print('y_ref', y_ref)
-
-        sqnr_ref = compute_error(y_ref, y_q_ref)
-        sqnr_our = compute_error(y_ref, y)
-        # print('sqnr_ref', sqnr_ref, 'sqnr_our', sqnr_our)
-        # for large shapes, sqnr can be in the high 30s for float32 and float16
-        self.assertTrue(sqnr_our.item() >= 37.5)
-
-    def test_qlinear_per_channel_numerics_cpu(self):
-        # Note: the AO codebase doesn't easily support qint8 activations,
-        # so the test cases below are for the quant primitives defined in
-        # this file only. The AO reference is using quint8 here.
-        test_cases = (
-            ((2, 3), (3, 4), 0, 255, torch.uint8, torch.quint8, torch.float32, "cpu"),
-            ((2, 3), (3, 4), -128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),
-        )
-        for test_case in test_cases:
-            self._test_qlinear_per_channel_numerics(*test_case)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_qlinear_per_channel_numerics_cuda(self):
-        test_cases = (
-            # Note: torch._int_mm needs int8 activations, so we don't test uint8
-            # activations on CUDA at all
-            (
-                (32, 32),
-                (32, 16),
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cuda",
-            ),
-            (
-                (32, 32),
-                (32, 16),
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float16,
-                "cuda",
-            ),
-            # a large shape from LLaMa 1.5B - currently fails for float16
-            (
-                (17, 4096),
-                (4096, 1536),
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float32,
-                "cuda",
-            ),
-            (
-                (17, 4096),
-                (4096, 1536),
-                -128,
-                127,
-                torch.int8,
-                torch.qint8,
-                torch.float16,
-                "cuda",
-            ),
-        )
-        for test_case in test_cases:
-            self._test_qlinear_per_channel_numerics(*test_case)
-
-
 class TestSubclass(unittest.TestCase):
     @run_supported_device_dtype
     def _test_dequantize_impl(
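
The removed PythonQuantPrimitivesUnitTest above exercised helpers such as dynamically_quantize_per_tensor and quant_int8_dynamic_linear, which the commit summary describes as expressible with the lower-level affine primitives. A rough, hypothetical sketch of that idea (not the actual torchao.quantization.utils implementation; it assumes the choose_qparams_affine / quantize_affine / dequantize_affine signatures imported at the top of this test file):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

def dynamic_quant_per_tensor_sketch(x: torch.Tensor, qmin: int = -128, qmax: int = 127):
    # One scale/zero_point for the whole tensor: block_size covers every dim.
    block_size = tuple(x.shape)
    # Hypothetical call; the real helpers may thread eps/dtype options differently.
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, block_size, torch.int8, qmin, qmax
    )
    x_q = quantize_affine(x, block_size, scale, zero_point, torch.int8, qmin, qmax)
    x_dq = dequantize_affine(x_q, block_size, scale, zero_point, torch.int8, qmin, qmax)
    return x_q, x_dq, scale, zero_point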

Diff for: test/quantization/test_qat.py

+1 -1

@@ -18,7 +18,7 @@
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
-from torchao.quantization.quant_primitives import get_group_qparams_symmetric
+from torchao.quantization.utils import get_group_qparams_symmetric
 from torchao.utils import TORCH_VERSION_AFTER_2_4


Comments (0)