diff --git a/nemo/quant/pact.py b/nemo/quant/pact.py index 3c3554d..1fa4768 100644 --- a/nemo/quant/pact.py +++ b/nemo/quant/pact.py @@ -51,6 +51,7 @@ def pact_quantized_requantize(t, eps_in, eps_out, D=1, exclude_requant_rounding= # re-quantize from a lower precision (larger eps_in) to a higher precision (lower eps_out) def pact_integer_requantize(t, eps_in, eps_out, D=1): D = D.clone().detach().to(eps_in.device) + # D = torch.tensor(D, device=eps_in.device) eps_ratio = (D*eps_in/eps_out).round() device = t.device return torch.as_tensor((t.clone().detach().type(torch.int64) * eps_ratio.clone().detach().type(torch.int64) // D), dtype=torch.float32, device=device) @@ -372,7 +373,7 @@ def set_static_precision(self, limit_at_32_bits=True, **kwargs): if not limit_at_32_bits: self.D = D else: - self.D = min(D, 2.0**(32-1-(self.precision.get_bits()))) + self.D = min(D, torch.tensor(2.0**(32-1-(self.precision.get_bits())), device = D.device)) def get_output_eps(self, eps_in): r"""Get the output quantum (:math:`\varepsilon`) given the input one. @@ -610,8 +611,9 @@ def set_output_eps(self, limit_at_32_bits=True, **kwargs): if not limit_at_32_bits: self.D = D else: - self.D = min(D, 2**(32-(self.precision.get_bits()))) - + # self.D = min(D, 2**(32-(self.precision.get_bits()))) + self.D = min(D, torch.tensor([2.0**(32-1-(self.precision.get_bits()))], device = D.device, dtype=torch.int64)) + def get_output_eps(self, eps_in): r"""Get the output quantum (:math:`\varepsilon`) given the input one.