diff --git a/.gitignore b/.gitignore
index 9e6f1a45cc..2f80c81d5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,3 +106,4 @@ tests/mylib.onnxlib
 **/serde_test_profiles/*
 tools/ort_rewriter_profiling/.logs/*
 tools/ort_rewriter_profiling/onnx_models/*
+/dump_TestOperatorsOnnxrt
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b8535d46c7..21be2944b6 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -2093,6 +2093,7 @@ def _aten_convolution_onnx(
     return result
 
 
+@torch_op("aten::convolution_backward", trace_only=True)
 def aten_convolution_backward(
     grad_output: TensorType,
     input: TensorType,
@@ -2108,7 +2109,87 @@ def aten_convolution_backward(
 ) -> tuple[TensorType, TensorType, TensorType]:
     """convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"""
 
-    raise NotImplementedError()
+    # Compute weight.grad: dW_t = X_t * dZ_t
+    input_t = op.Transpose(input, perm=[1, 0, 2, 3])
+    dz_t = op.Transpose(grad_output, perm=[1, 0, 2, 3])
+    dw_t = op.Conv(input_t, dz_t)
+    dw = op.Transpose(dw_t, perm=[1, 0, 2, 3])
+    axes = op.Constant(value_ints=[0, 2, 3])
+    db = op.ReduceSum(grad_output, axes, keepdims=0)
+
+    # Compute x.grad: dx = dZ (zero-padded) * W_rot180
+    # Assume: grad_output=(20,13,48,38)
+    z_height = op.Shape(grad_output, start=2, end=3)  # 48
+    z_width = op.Shape(grad_output, start=3, end=4)  # 38
+
+    if stride[0] != 1 or stride[1] != 1:
+        raise NotImplementedError("stride != 1 is not supported yet")
+    # if stride[0] != 1:  # dilation
+    #     dz_height = z_height * stride[0] - stride[0] + 1
+    #     dz_width = z_width * stride[1] - stride[1] + 1
+    #     pos = _help(z_height, dz_width, stride)
+    #     pos = []
+    #     for j in range(z_height):
+    #         for i in range(0, dz_width, stride[1]):
+    #             pos.append(i + j * dz_width * stride[0])
+
+    #     index_tensor = op.Constant(value_ints=pos)
+    #     index_tensor = op.Reshape(index_tensor, z_shape)
+    #     # this does not work because kernel_shape is an attribute
+    #     dz = op.MaxUnpool(grad_output, index_tensor, kernel_shape=[dz_height - z_height + 1, dz_width - z_width + 1])
+
+    # Compute the padding size
+    # Assume: input=(20,16,50,40)
+    x_height = op.Shape(input, start=2, end=3)  # 50
+    x_width = op.Shape(input, start=3, end=4)  # 40
+    # Assume: weight=(13,16,3,3)
+    w_height = op.Shape(weight, start=2, end=3)  # 3
+    w_width = op.Shape(weight, start=3, end=4)  # 3
+    tmp_int = x_height - z_height + w_height - 1  # 50-48+3-1=4
+    tmp_float = op.Cast(tmp_int, to=FLOAT.dtype)
+    pad_height = op.Cast(
+        op.Div(tmp_float, op.Constant(value_floats=[2.0])), to=INT64.dtype
+    )  # 4/2=2
+    tmp_int = x_width - z_width + w_width - 1  # 40-38+3-1=4
+    tmp_float = op.Cast(tmp_int, to=FLOAT.dtype)
+    pad_width = op.Cast(
+        op.Div(tmp_float, op.Constant(value_floats=[2.0])), to=INT64.dtype
+    )  # 4/2=2
+    pads = op.Concat(  # [0,0,2,2,0,0,2,2]
+        # begin of dim0, dim1, dim2, dim3
+        op.Constant(value_ints=[0]),
+        op.Constant(value_ints=[0]),
+        pad_height,
+        pad_width,
+        # end of dim0, dim1, dim2, dim3
+        op.Constant(value_ints=[0]),
+        op.Constant(value_ints=[0]),
+        pad_height,
+        pad_width,
+        axis=0,
+    )
+    dz_pad = op.Pad(grad_output, pads)  # enlarge the grad_output to (20,13,52,42)
+
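+    # The block below builds the 180-degree-rotated weights needed for dx.
+    # NumPy sketch of the intent (for intuition only, not executed here):
+    #     w_rot180 = np.transpose(w, (1, 0, 2, 3))[:, :, ::-1, ::-1]
+    # ONNX has no rot180 operator, so it is emulated with Reshape plus a
+    # negative-step Slice over the flattened spatial dimensions.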
+    # Transpose the weight from (13,16,3,3) to (16,13,3,3)
+    w_transpose = op.Transpose(weight, perm=[1, 0, 2, 3])
+    # Rotate the transposed weight by 180 degrees in the spatial dims; the shape stays (16,13,3,3)
+    w_shape_0 = op.Shape(w_transpose, start=0, end=1)  # 16
+    w_shape_1 = op.Shape(w_transpose, start=1, end=2)  # 13
+    w_shape_2 = op.Constant(value_ints=[1])  # 1
+    w_shape_3 = op.Constant(value_ints=[-1])  # -1
+    w_shape_new = op.Concat(w_shape_0, w_shape_1, w_shape_2, w_shape_3, axis=0)  # (16,13,1,-1)
+    w_new = op.Reshape(w_transpose, w_shape_new)  # reshape to (16,13,1,-1): flatten the spatial dims
+    # reverse the values in the last dim (axis 3), e.g. [1,2,3,...,9] -> [9,...,3,2,1]
+    starts = op.Constant(value_ints=[-1])
+    ends = op.Constant(value_ints=[-1000])
+    axes = op.Constant(value_ints=[3])
+    steps = op.Constant(value_ints=[-1])
+    w_slice = op.Slice(w_new, starts, ends, axes, steps)  # weight[:, :, :, -1:-1000:-1]
+    weight_rot180 = op.Reshape(w_slice, op.Shape(w_transpose))  # reshape back to (16,13,3,3)
+    # dx = dz (zero-padded) * w (rot180)
+    dx = op.Conv(dz_pad, weight_rot180)
+    # TODO: when dx is larger than input (e.g. 29x29 vs. 28x28), the last row and column of dx must be dropped
+    return dx, dw, db
 
 
 def aten_convolution_backward_overrideable(
@@ -4659,7 +4740,7 @@ def aten_le(self: TReal, other: TReal) -> BOOL:
     return op.LessOrEqual(self, other)
 
 
-@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
+@torch_op(("aten::le.Scalar", "aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
 def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
     """le.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -6583,10 +6664,18 @@ def aten_prelu_backward(
     raise NotImplementedError()
 
 
-def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
+@torch_op("aten::prod", trace_only=True)
+def aten_prod(self: TReal, dtype: Optional[int] = None) -> TReal:
     """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.ReduceProd(self)
+
+
+@torch_op("aten::prod.dim_int", trace_only=True)
+def aten_prod_dim(self: TReal, dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> TReal:
+    """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
+
+    return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
 
 
 def aten_promote_types(type1: int, type2: int) -> int:
diff --git a/onnxscript/tools/training_helper.py b/onnxscript/tools/training_helper.py
index 785b2e6fb3..6fe245ba09 100644
--- a/onnxscript/tools/training_helper.py
+++ b/onnxscript/tools/training_helper.py
@@ -2,13 +2,19 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+from __future__ import annotations
+
+import glob
+import os
+from typing import Any
+
 import torch
 from torch.onnx import ExportOptions
 from torch.onnx import _OrtBackend as OrtBackend
 from torch.onnx import _OrtBackendOptions as OrtBackendOptions
 
 
-def make_aot_ort(dynamic: bool = False):
+def make_aot_ort(dynamic: bool = False) -> Any:
     """Implements an autograd backend for torch.compile based on onnxrt backend."""
     export_options = ExportOptions(dynamic_shapes=dynamic)
     options = OrtBackendOptions(export_options=export_options)
@@ -16,8 +22,39 @@ def make_aot_ort(dynamic: bool = False):
     return ort_backend
 
 
-def train_loop(model, *args, loss_fn=None, optimizer=None):
-    """Implements a training loop to be used in tests."""
+def train_loop(
+    model: Any,
+    *args,
+    loss_fn: Any | None = None,
+    optimizer: Any | None = None,
+    dump_onnx_models: bool = False,
+    dump_prefix: str = "dump_train_loop",
+    dump_clean_first: bool = True,
+) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
+    """Implements a training loop to be used in tests.
+
+    The function returns the forward outputs and the gradients in a tuple.
+    If dump_onnx_models is True, it additionally returns the list of ONNX
+    files generated by the onnxrt backend. If there is no graph break,
+    there should be exactly two graphs: one for the forward pass and one
+    for the backward pass.
+
+    Args:
+        model: pytorch model
+        args: model inputs
+        loss_fn: loss function, defaults to MSELoss
+        optimizer: optimizer, defaults to SGD
+        dump_onnx_models: dump the models produced by the onnxrt backend
+        dump_prefix: prefix used for the dumped ONNX files
+        dump_clean_first: remove every file starting with the given prefix first
+
+    Returns:
+        - the forward outputs
+        - the backward gradients
+        - the dumped ONNX models (at least 2), unless forward and backward
+          were already run before this function executed or the model
+          is not a compiled model
+    """
     if loss_fn is None:
         loss_fn = torch.nn.MSELoss()
 
@@ -28,6 +65,16 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     # Unnecessary in this situation but added for best practices
     model.train()
 
+    if dump_onnx_models:
+        if dump_clean_first:
+            names = glob.glob(f"{dump_prefix}*")
+            for name in names:
+                os.remove(name)
+
+        old_value = os.environ.get("ONNXRT_DUMP_PATH", None)
+        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_forward"
+        existing_files = glob.glob(f"{dump_prefix}*.onnx")
+
     # Compute prediction and loss
     pred = model(*args)
     if isinstance(pred, tuple):
@@ -39,6 +86,8 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
         loss = loss_fn(v, torch.ones_like(v))
 
     # Backpropagation
+    if dump_onnx_models:
+        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_backward"
     loss.backward()
     optimizer.step()
     # skip that part to retrieve the gradients
@@ -47,4 +96,14 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     # returns the gradients
     res = tuple(p.grad for p in model.parameters() if p.grad is not None)
     assert len(res) > 0, f"No gradient, loss is {loss}"
-    return res
+
+    if dump_onnx_models:
+        if old_value is None:
+            del os.environ["ONNXRT_DUMP_PATH"]
+        else:
+            os.environ["ONNXRT_DUMP_PATH"] = old_value
+        new_files = glob.glob(f"{dump_prefix}*.onnx")
+        added_files = set(new_files) - set(existing_files)
+        return pred, res, [f for f in new_files if f in added_files]
+
+    return pred, res
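A minimal usage sketch of the extended train_loop (illustration only, not from the patch; the model and dump prefix below are made up):

import torch

import onnxscript.tools.training_helper as training_helper

model = torch.nn.Linear(4, 2)
compiled_model = torch.compile(model, backend=training_helper.make_aot_ort(dynamic=False))
pred, grads, onnx_files = training_helper.train_loop(
    compiled_model,
    torch.randn(8, 4),
    dump_onnx_models=True,
    dump_prefix="dump_example",
)
# With no graph break, onnx_files should contain one forward and one backward model.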
diff --git a/tests/function_libs/torch_lib/backward_test.py b/tests/function_libs/torch_lib/backward_test.py
new file mode 100644
index 0000000000..360ca5312b
--- /dev/null
+++ b/tests/function_libs/torch_lib/backward_test.py
@@ -0,0 +1,109 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# pylint: disable=not-callable
+
+import copy
+import sys
+import unittest
+
+import torch
+
+import onnxscript.tools.training_helper
+import onnxscript.tools.transformers_models
+import onnxscript.tools.transformers_models.llama
+from onnxscript._internal.version_utils import has_transformers, torch_older_than
+
+
+class TestBackward(unittest.TestCase):
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    def test_backward_working(self):
+        class SimpleNN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.fc1 = torch.nn.Linear(14, 10)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.fc1(x))
+
+        input_tensors = (torch.randn(1, 1, 14, 14),)
+        model = SimpleNN()
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_testbw_working",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_close(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
+
+        # Check that exactly two graphs were generated; more means there were graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    # @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    # @unittest.skipIf(True, reason="aten.conv_backward not implemented yet.")
+    def test_backward_conv(self):
+        class SimpleCNN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=1,
+                    out_channels=2,
+                    kernel_size=3,
+                    padding=(0, 0),  # padding=1 is not supported yet
+                )
+                self.fc1 = torch.nn.Linear(12, 10)
+
+            def forward(self, x):
+                y = torch.nn.functional.relu(self.conv1(x))
+                z = self.fc1(y)
+                return z
+
+        input_tensors = (torch.randn(1, 1, 14, 14),)
+        model = SimpleCNN()
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_testbw_conv",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_close(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
+
+        # Check that exactly two graphs were generated; more means there were graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
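For reference, a standalone PyTorch sketch (illustrative only, not from the patch) of the identities the new aten_convolution_backward relies on, restricted to stride=1, padding=0, dilation=1, groups=1; it uses the example shapes from the comments above:

import torch

x = torch.randn(20, 16, 50, 40, requires_grad=True)
w = torch.randn(13, 16, 3, 3, requires_grad=True)
z = torch.nn.functional.conv2d(x, w)  # (20, 13, 48, 38)
gz = torch.randn_like(z)
z.backward(gz)  # reference gradients from autograd

xd, wd = x.detach(), w.detach()
# dW: correlate the input with the upstream gradient (batch and channel dims swapped).
dw = torch.nn.functional.conv2d(xd.transpose(0, 1), gz.transpose(0, 1)).transpose(0, 1)
# dX: zero-pad the upstream gradient and correlate with the 180-degree-rotated weights.
gz_pad = torch.nn.functional.pad(gz, (2, 2, 2, 2))
dx = torch.nn.functional.conv2d(gz_pad, wd.transpose(0, 1).flip(-1, -2))

# Loose tolerances: float32 accumulation over many terms.
torch.testing.assert_close(dw, w.grad, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(dx, x.grad, atol=1e-3, rtol=1e-3)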