Description
Current status (without onnx quantization)
For pytorch-lightning users:
```python
pl_model = Trainer.compile(model, loss, optim)  # skip if you have a pl model
trainer.fit(pl_model, dataloader)
pl_model_quantized = trainer.quantize(pl_model, dataloader)
pl_model_quantized(x)  # quantized inference
```
For pytorch users (potentially; we do not support this yet, but it could be supported fairly simply):
```python
# >>>>>>>> start of pytorch training loop >>>>>>>>>>>
# ...
# <<<<<<<< end of pytorch training loop <<<<<<<<<<<<
model_quantized = trainer.quantize(model, dataloader)
model_quantized(x)  # quantized inference
```
Issue
`pl_model_quantized` and `model_quantized` are `torch.fx.graph_module.GraphModule` instances, an unfamiliar type to users (see the sketch after the list below for how this type typically arises). Users cannot:
- use `onnx.export` to trace the quantized model.
- use the normal way to save or load the model.
- continue training on this quantized model.
- It is also extremely hard to integrate onnx quantization with INC into this API.
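For context, here is a minimal sketch of post-training static quantization with torch FX, which is roughly the kind of flow (as used underneath INC's pytorch fx backend) that ends up handing users a `torch.fx.GraphModule`. API details vary across torch versions; this sketch follows the `torch.ao.quantization.quantize_fx` API, and `model` / `dataloader` are the same objects used in the snippets above:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model.eval()
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
example_inputs = (next(iter(dataloader))[0],)

# insert observers into the traced graph
prepared = prepare_fx(model, qconfig_mapping, example_inputs)

# calibration pass over representative data (assumes (x, y) batches)
with torch.no_grad():
    for x, _ in dataloader:
        prepared(x)

quantized = convert_fx(prepared)
print(type(quantized))  # torch.fx.graph_module.GraphModule (or a subclass)
```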
Revised API usage (with onnx quantization)
For pytorch-lightning users:
```python
pl_model = Trainer.compile(model, loss, optim, onnx=True/False)  # skip if you have a pl model
trainer.fit(pl_model, dataloader)
pl_model = trainer.quantize(pl_model, dataloader, onnx=True/False)
```
For pytorch users:
```python
model = Trainer.compile(model, onnx=True/False)
# >>>>>>>> start of pytorch training loop >>>>>>>>>>>
# ...
# <<<<<<<< end of pytorch training loop <<<<<<<<<<<<
model = trainer.quantize(model, dataloader, onnx=True/False)
```
`pl_model` and `model` are still pytorch-lightning models, so prediction can be done as follows:
```python
# predict with pytorch fp32
pl_model.eval()
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend=None)

# predict with pytorch int8
pl_model.eval(quantize=True)
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend=None, quantize=True)

# predict with onnx fp32
pl_model.eval_onnx()
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend="onnx")

# predict with onnx int8
pl_model.eval_onnx(quantize=True)
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend="onnx", quantize=True)
```
We should also provide an option to return the "raw" result to users: onnx quantization would return an onnx model (`onnx.onnx_ml_pb2.ModelProto`) and pytorch fx quantization would return an fx model (`torch.fx.GraphModule`).
```python
model = trainer.quantize(..., raw_return=True)  # defaults to False
```
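A possible usage of the raw return (the `trainer.quantize` call is the API proposed above; saving and inspecting the returned objects is plain onnx / torch usage):

```python
import onnx
import torch

# with onnx quantization, the raw result would be an onnx.ModelProto
raw_onnx = trainer.quantize(pl_model, dataloader, onnx=True, raw_return=True)
onnx.save(raw_onnx, "quantized_model.onnx")

# with pytorch fx quantization, the raw result would be a torch.fx.GraphModule
raw_fx = trainer.quantize(pl_model, dataloader, onnx=False, raw_return=True)
print(raw_fx.graph)  # inspect the traced, quantized graph
torch.save(raw_fx.state_dict(), "quantized_fx_state.pt")
```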
So basically three PRs will be raised separately for this issue:
- change `Trainer.quantize(..., raw_return=False)` to return a pl model and support the easy inference API (`.eval(quantize=True)`): Nano quantize inference API for pytorch #3866
- add support for onnx quantization
- add support for save and load (see the sketch below)
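For the save/load PR, a hedged sketch of what the user-facing flow could look like; `trainer.save` and `Trainer.load` are assumptions here, not an existing or agreed API:

```python
# hypothetical save/load flow; trainer.save / Trainer.load are assumptions
trainer.save(pl_model, "nano_model_dir/")   # persist fp32 weights + quantized artifacts
restored = Trainer.load("nano_model_dir/")  # restore a pl model with the same inference API

restored.eval(quantize=True)
with torch.no_grad():
    restored(x)  # quantized inference after reload
```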