28 | 28 |      _replace_with_custom_fn_if_matches_filter,
29 | 29 |  )
30 | 30 |  from torchao.quantization.quant_primitives import (
   | 31 | +    safe_int_mm,
   | 32 | +    choose_qparams_affine,
   | 33 | +    quantize_affine,
   | 34 | +    dequantize_affine,
   | 35 | +    MappingType,
   | 36 | +)
   | 37 | +from torchao.quantization.utils import (
31 | 38 |      dequantize_per_channel,
32 | 39 |      dequantize_per_tensor,
33 | 40 |      dynamically_quantize_per_channel,
34 |    | -    dynamically_quantize_per_tensor,
35 |    | -    quant_int8_dynamic_linear,
36 | 41 |      quant_int8_dynamic_per_token_linear,
37 | 42 |      quantize_activation_per_token_absmax,
38 |    | -    safe_int_mm,
39 |    | -    dequantize_affine,
40 | 43 |  )
41 | 44 |
42 | 45 |  from torchao.quantization.smoothquant import (
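For context on the new imports above: the affine primitives express quantization granularity through a `block_size` argument rather than through separate per-tensor and per-channel helpers. Below is a minimal sketch of per-channel int8 weight quantization using them. This is my illustration, not code from this diff; the argument order follows the torchao signatures as I understand them around the time of this change and may have shifted since.

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

w = torch.randn(32, 64)  # (out_features, in_features)

# one block per output channel -> per-channel (axis 0) granularity;
# block_size == w.shape would give per-tensor granularity instead
block_size = (1, w.shape[1])

scale, zero_point = choose_qparams_affine(
    w, MappingType.SYMMETRIC, block_size, torch.int8, quant_min=-128, quant_max=127
)
w_int8 = quantize_affine(
    w, block_size, scale, zero_point, torch.int8, quant_min=-128, quant_max=127
)
w_dq = dequantize_affine(
    w_int8, block_size, scale, zero_point, torch.int8, quant_min=-128, quant_max=127
)
print((w - w_dq).abs().max())  # small per-channel reconstruction error expected
```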
@@ -369,167 +372,7 @@ def test_debug_x_absmax(self):
369 | 372 |         y1 = m(x0)
370 | 373 |
371 | 374 |
372 |     | -class PythonQuantPrimitivesUnitTest(unittest.TestCase):
373 |     | -    def _test_dynamic_quant_per_tensor_numerics_impl(
374 |     | -        self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme
375 |     | -    ):
376 |     | -        x = torch.randn(256, dtype=float_dtype, device=device)
377 |     | -        y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor(
378 |     | -            x, qmin, qmax, int_dtype, qscheme
379 |     | -        )
380 |     | -
381 |     | -        # reference
382 |     | -        # quantize_per_tensor_dynamic doesn't work for half, so we cast there and back
383 |     | -        x_for_ref = x.half().float() if float_dtype == torch.float16 else x
384 |     | -
385 |     | -        # quantize_per_tensor_dynamic doesn't support qscheme, so we just do dynamic
386 |     | -        # quant manually with observers + static quant
387 |     | -        obs = MinMaxObserver(
388 |     | -            dtype=qint_dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax
389 |     | -        ).to(device)
390 |     | -        obs(x_for_ref)
391 |     | -        ref_scale, ref_zero_point = obs.calculate_qparams()
392 |     | -        y_ref = torch.quantize_per_tensor(
393 |     | -            x_for_ref, ref_scale, ref_zero_point, qint_dtype
394 |     | -        )
395 |     | -
396 |     | -        # y_ref = torch.quantize_per_tensor_dynamic(x_for_ref, qint_dtype, False)
397 |     | -        # print(y_ref)
398 |     | -        if float_dtype == torch.float:
399 |     | -            assert torch.equal(y_vals, y_ref.int_repr())
400 |     | -        else:
401 |     | -            # numerics are not exactly aligned yet, off-by-one probably due
402 |     | -            # to rounding
403 |     | -            assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1
404 |     | -        torch.testing.assert_close(
405 |     | -            y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype)
406 |     | -        )
407 |     | -        if y_zero_point is not None:
408 |     | -            assert torch.equal(
409 |     | -                y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device)
410 |     | -            )
411 |     | -        else:
412 |     | -            self.assertTrue(y_ref.q_zero_point() == 0)
413 |     | -
414 |     | -        # dequantize and check again
415 |     | -        x_dq = dequantize_per_tensor(y_vals, y_scale, y_zero_point, float_dtype)
416 |     | -        y_ref_dq = y_ref.dequantize().to(float_dtype)
417 |     | -        if float_dtype == torch.float:
418 |     | -            torch.testing.assert_close(x_dq, y_ref_dq)
419 |     | -        else:
420 |     | -            sqnr = compute_error(x_dq, y_ref_dq)
421 |     | -            self.assertTrue(sqnr.item() > 45.0)
422 |     | -
423 |     | -    def test_dynamic_quant_per_tensor_numerics_cpu(self):
424 |     | -        # verifies that dynamic quant per tensor in plain pytorch matches
425 |     | -        # numerics of production AO code
426 |     | -        # TODO(future): test this on cpu-half, need to first make
427 |     | -        # torch.aminmax support half on cpu
428 |     | -        test_cases = (
429 |     | -            (
430 |     | -                0,
431 |     | -                255,
432 |     | -                torch.uint8,
433 |     | -                torch.quint8,
434 |     | -                torch.float32,
435 |     | -                "cpu",
436 |     | -                torch.per_tensor_affine,
437 |     | -            ),
438 |     | -            (
439 |     | -                -128,
440 |     | -                127,
441 |     | -                torch.int8,
442 |     | -                torch.qint8,
443 |     | -                torch.float32,
444 |     | -                "cpu",
445 |     | -                torch.per_tensor_affine,
446 |     | -            ),
447 |     | -            (
448 |     | -                -128,
449 |     | -                127,
450 |     | -                torch.int8,
451 |     | -                torch.qint8,
452 |     | -                torch.float32,
453 |     | -                "cpu",
454 |     | -                torch.per_tensor_symmetric,
455 |     | -            ),
456 |     | -            (
457 |     | -                -127,
458 |     | -                127,
459 |     | -                torch.int8,
460 |     | -                torch.qint8,
461 |     | -                torch.float32,
462 |     | -                "cpu",
463 |     | -                torch.per_tensor_symmetric,
464 |     | -            ),
465 |     | -        )
466 |     | -        for row in test_cases:
467 |     | -            self._test_dynamic_quant_per_tensor_numerics_impl(*row)
468 |     | -
469 |     | -    @unittest.skip("test case incorrect on A10G")
470 |     | -    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
471 |     | -    def test_dynamic_quant_per_tensor_numerics_cuda(self):
472 |     | -        # verifies that dynamic quant per tensor in plain pytorch matches
473 |     | -        # numerics of production AO code
474 |     | -        test_cases = (
475 |     | -            (
476 |     | -                -128,
477 |     | -                127,
478 |     | -                torch.int8,
479 |     | -                torch.qint8,
480 |     | -                torch.float32,
481 |     | -                "cuda",
482 |     | -                torch.per_tensor_affine,
483 |     | -            ),
484 |     | -            (
485 |     | -                -128,
486 |     | -                127,
487 |     | -                torch.int8,
488 |     | -                torch.qint8,
489 |     | -                torch.float16,
490 |     | -                "cuda",
491 |     | -                torch.per_tensor_affine,
492 |     | -            ),
493 |     | -            (
494 |     | -                -128,
495 |     | -                127,
496 |     | -                torch.int8,
497 |     | -                torch.qint8,
498 |     | -                torch.float32,
499 |     | -                "cuda",
500 |     | -                torch.per_tensor_symmetric,
501 |     | -            ),
502 |     | -            (
503 |     | -                -128,
504 |     | -                127,
505 |     | -                torch.int8,
506 |     | -                torch.qint8,
507 |     | -                torch.float16,
508 |     | -                "cuda",
509 |     | -                torch.per_tensor_symmetric,
510 |     | -            ),
511 |     | -            (
512 |     | -                -127,
513 |     | -                127,
514 |     | -                torch.int8,
515 |     | -                torch.qint8,
516 |     | -                torch.float32,
517 |     | -                "cuda",
518 |     | -                torch.per_tensor_symmetric,
519 |     | -            ),
520 |     | -            (
521 |     | -                -127,
522 |     | -                127,
523 |     | -                torch.int8,
524 |     | -                torch.qint8,
525 |     | -                torch.float16,
526 |     | -                "cuda",
527 |     | -                torch.per_tensor_symmetric,
528 |     | -            ),
529 |     | -        )
530 |     | -        for row in test_cases:
531 |     | -            self._test_dynamic_quant_per_tensor_numerics_impl(*row)
532 |     | -
    | 375 | +class PythonQuantUtilOpUnitTest(unittest.TestCase):
533 | 376 |     def _test_dynamic_quant_per_channel_numerics_impl(
534 | 377 |         self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
535 | 378 |     ):
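The PythonQuantPrimitivesUnitTest class removed in the hunk above checked dynamically_quantize_per_tensor against the eager-mode AO reference (MinMaxObserver plus torch.quantize_per_tensor) for both torch.per_tensor_affine and torch.per_tensor_symmetric. If that per-tensor coverage is ever wanted back on the new primitives, a rough sketch follows; the qscheme-to-MappingType correspondence and the local sqnr helper are my assumptions, not part of this change.

```python
import torch
from torch.ao.quantization import MinMaxObserver
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)


def sqnr(ref, res):
    # signal-to-quantization-noise ratio in dB (stand-in for compute_error)
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - res))


x = torch.randn(256)

# assumed correspondence between the old qschemes and MappingType
for qscheme, mapping_type in (
    (torch.per_tensor_affine, MappingType.ASYMMETRIC),
    (torch.per_tensor_symmetric, MappingType.SYMMETRIC),
):
    # eager-mode reference, as in the removed test: observer + static quant
    obs = MinMaxObserver(
        dtype=torch.qint8, qscheme=qscheme, quant_min=-128, quant_max=127
    )
    obs(x)
    ref_scale, ref_zp = obs.calculate_qparams()
    y_ref = torch.quantize_per_tensor(x, ref_scale, ref_zp, torch.qint8)

    # affine-primitive path; block_size == x.shape means per-tensor granularity
    scale, zp = choose_qparams_affine(
        x, mapping_type, x.shape, torch.int8, quant_min=-128, quant_max=127
    )
    x_int8 = quantize_affine(
        x, x.shape, scale, zp, torch.int8, quant_min=-128, quant_max=127
    )
    x_dq = dequantize_affine(
        x_int8, x.shape, scale, zp, torch.int8, quant_min=-128, quant_max=127
    )

    # both paths should reconstruct x about equally well (SQNR within a few dB)
    print(qscheme, sqnr(x, y_ref.dequantize()).item(), sqnr(x, x_dq).item())
```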
@@ -705,130 +548,6 @@ def wrap_torch_int_mm(x, w):
705 | 548 |         torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0)
706 | 549 |         torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0)
707 | 550 |
708 |     | -    def _test_qlinear_per_channel_numerics(
709 |     | -        self, x_shape, lin_shape, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
710 |     | -    ):
711 |     | -        qconfig = torch.ao.quantization.per_channel_dynamic_qconfig
712 |     | -
713 |     | -        x = torch.randn(*x_shape, device=device, dtype=float_dtype)
714 |     | -
715 |     | -        # TODO: test bias true and false
716 |     | -        # Note: reference path only works on float because lack of aten quant primitives
717 |     | -        # support of half, so we cast back and forth to emulate
718 |     | -        lin_ref = (
719 |     | -            nn.Sequential(nn.Linear(*lin_shape))
720 |     | -            .eval()
721 |     | -            .to(float_dtype)
722 |     | -            .float()
723 |     | -            .to(device)
724 |     | -        )
725 |     | -        y_ref = lin_ref(x.float())
726 |     | -        weight = lin_ref[0].weight
727 |     | -        bias = lin_ref[0].bias
728 |     | -
729 |     | -        qconfig_mapping = QConfigMapping().set_global(qconfig)
730 |     | -        lin_ref_p = prepare_fx(lin_ref, qconfig_mapping, (torch.randn(1, 1),))
731 |     | -        lin_ref_q = convert_to_reference_fx(lin_ref_p)
732 |     | -        y_q_ref = lin_ref_q(x.float())
733 |     | -
734 |     | -        # scale, zp of weight (get from reference model)
735 |     | -        w_obs = qconfig.weight()
736 |     | -        w_obs(weight)
737 |     | -        lin_ref_w_scale, lin_ref_w_zp = w_obs.calculate_qparams()
738 |     | -        lin_ref_w_scale = lin_ref_w_scale.to(device).to(float_dtype)
739 |     | -        # print('lin_ref_w', 'scale', lin_ref_w_scale, 'zp', lin_ref_w_zp)
740 |     | -
741 |     | -        w_vals, _s, _z = dynamically_quantize_per_channel(
742 |     | -            getattr(lin_ref_q, "0").weight.to(float_dtype), -128, 127, torch.int8
743 |     | -        )
744 |     | -        w_vals = w_vals.t().contiguous()
745 |     | -        w_vals_sums = w_vals.sum(dim=0)
746 |     | -
747 |     | -        # do our version of the quantized linear operator
748 |     | -        y = quant_int8_dynamic_linear(
749 |     | -            x,
750 |     | -            qmin,
751 |     | -            qmax,
752 |     | -            int_dtype,
753 |     | -            w_vals,
754 |     | -            lin_ref_w_scale,
755 |     | -            w_vals_sums,
756 |     | -            bias,
757 |     | -            float_dtype,
758 |     | -        )
759 |     | -
760 |     | -        # print('y', y)
761 |     | -        # print('y_q_ref', y_q_ref)
762 |     | -        # print('y_ref', y_ref)
763 |     | -
764 |     | -        sqnr_ref = compute_error(y_ref, y_q_ref)
765 |     | -        sqnr_our = compute_error(y_ref, y)
766 |     | -        # print('sqnr_ref', sqnr_ref, 'sqnr_our', sqnr_our)
767 |     | -        # for large shapes, sqnr can be in the high 30s for float32 and float16
768 |     | -        self.assertTrue(sqnr_our.item() >= 37.5)
769 |     | -
770 |     | -    def test_qlinear_per_channel_numerics_cpu(self):
771 |     | -        # Note: the AO codebase doesn't easily support qint8 activations,
772 |     | -        # so the test cases below are for the quant primitives defined in
773 |     | -        # this file only. The AO reference is using quint8 here.
774 |     | -        test_cases = (
775 |     | -            ((2, 3), (3, 4), 0, 255, torch.uint8, torch.quint8, torch.float32, "cpu"),
776 |     | -            ((2, 3), (3, 4), -128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),
777 |     | -        )
778 |     | -        for test_case in test_cases:
779 |     | -            self._test_qlinear_per_channel_numerics(*test_case)
780 |     | -
781 |     | -    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
782 |     | -    def test_qlinear_per_channel_numerics_cuda(self):
783 |     | -        test_cases = (
784 |     | -            # Note: torch._int_mm needs int8 activations, so we don't test uint8
785 |     | -            # activations on CUDA at all
786 |     | -            (
787 |     | -                (32, 32),
788 |     | -                (32, 16),
789 |     | -                -128,
790 |     | -                127,
791 |     | -                torch.int8,
792 |     | -                torch.qint8,
793 |     | -                torch.float32,
794 |     | -                "cuda",
795 |     | -            ),
796 |     | -            (
797 |     | -                (32, 32),
798 |     | -                (32, 16),
799 |     | -                -128,
800 |     | -                127,
801 |     | -                torch.int8,
802 |     | -                torch.qint8,
803 |     | -                torch.float16,
804 |     | -                "cuda",
805 |     | -            ),
806 |     | -            # a large shape from LLaMa 1.5B - currently fails for float16
807 |     | -            (
808 |     | -                (17, 4096),
809 |     | -                (4096, 1536),
810 |     | -                -128,
811 |     | -                127,
812 |     | -                torch.int8,
813 |     | -                torch.qint8,
814 |     | -                torch.float32,
815 |     | -                "cuda",
816 |     | -            ),
817 |     | -            (
818 |     | -                (17, 4096),
819 |     | -                (4096, 1536),
820 |     | -                -128,
821 |     | -                127,
822 |     | -                torch.int8,
823 |     | -                torch.qint8,
824 |     | -                torch.float16,
825 |     | -                "cuda",
826 |     | -            ),
827 |     | -        )
828 |     | -        for test_case in test_cases:
829 |     | -            self._test_qlinear_per_channel_numerics(*test_case)
830 |     | -
831 |     | -
832 | 551 | class TestSubclass(unittest.TestCase):
833 | 552 |     @run_supported_device_dtype
834 | 553 |     def _test_dequantize_impl(
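The hunk above also drops _test_qlinear_per_channel_numerics, which exercised the now-removed quant_int8_dynamic_linear against an FX reference module. As a reminder of what that op computes, here is a rough hand-rolled equivalent built from primitives that survive this change (safe_int_mm and dynamically_quantize_per_channel); the manual absmax activation quantization and the rescaling step are my reconstruction, not the library implementation.

```python
import torch
from torchao.quantization.quant_primitives import safe_int_mm
from torchao.quantization.utils import dynamically_quantize_per_channel


def int8_dynamic_linear_sketch(x, w, bias=None):
    # dynamic symmetric per-tensor quantization of the activation (hand-rolled absmax)
    x_scale = x.abs().amax().clamp(min=1e-5) / 127.0
    x_int8 = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)

    # symmetric per-channel quantization of the weight, as in the remaining tests
    w_int8, w_scales, _ = dynamically_quantize_per_channel(w, -128, 127, torch.int8)

    # int8 x int8 -> int32 matmul; safe_int_mm falls back to a plain matmul
    # when torch._int_mm cannot be used for the given device/shape
    y_int32 = safe_int_mm(x_int8, w_int8.t().contiguous())

    # rescale back to float and add the bias
    y = y_int32.to(torch.float32) * x_scale * w_scales
    if bias is not None:
        y = y + bias
    return y


x = torch.randn(17, 64)
lin = torch.nn.Linear(64, 32)
y = int8_dynamic_linear_sketch(x, lin.weight.detach(), lin.bias.detach())
print((y - lin(x)).abs().max())  # small quantization error expected
```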