diff --git a/.github/workflows/ci-platform-siracusa.yml b/.github/workflows/ci-platform-siracusa.yml
index 7c6a5f754..f59f7fa88 100644
--- a/.github/workflows/ci-platform-siracusa.yml
+++ b/.github/workflows/ci-platform-siracusa.yml
@@ -53,7 +53,15 @@ jobs:
           testBacktracking
           testFloatAdder
           testFloatGEMM
+          testFloat2DConvolution
+          testFloat2DConvolutionBias
+          testFloat2DConvolutionZeroBias
+
+          testFloat2DDWConvolution
+          testFloat2DDWConvolutionBias
+          testFloat2DDWConvolutionZeroBias
+
           testFloatLayerNorm
           testFloatRelu
           testFloatMaxPool
@@ -64,6 +72,7 @@ jobs:
           Quant
           Dequant
           testFloatReduceSum
+          testFloatReshapeWithSkipConnection
           testFloatSoftmaxGrad
           testFloatSoftmaxCrossEntropy
           testFloatSoftmaxCrossEntropyGrad
@@ -87,4 +96,5 @@ jobs:
           CCT/CCT_1_16_16_8
           CCT/CCT_2_32_32_128_Opset20
           testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8
+          testFloatDemoTinyViT
       num-cores: 8
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5421cdf52..6ae6917b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid

 ## Unreleased (Planned Release Target: v0.2.1)
 ### List of Pull Requests
+- TinyViT on non-tiled Siracusa [#117](https://github.com/pulp-platform/Deeploy/pull/117)
 - Refactor Logging for Improved Debugging [#115](https://github.com/pulp-platform/Deeploy/pull/115)
 - Add reuse-tool as an SPDX license header linter [#113](https://github.com/pulp-platform/Deeploy/pull/113)
 - Bug fixes, API Cleanup and Reduce Compiler Warning on PULP [#112](https://github.com/pulp-platform/Deeploy/pull/112)
@@ -17,6 +18,13 @@ This file contains the changelog for the Deeploy project. The changelog is divid
 - Fix `Unsqueeze` Op. when using ONNX opset 13 or higher (from attribute to input) [#119](https://github.com/pulp-platform/Deeploy/pull/119)

 ### Added
+- PULP 2D FP DW conv Im2Col template and kernel, with bias support.
+- Bias support for PULP 2D FP regular conv Im2Col in template & kernel.
+- PULP FP DW conv 2D parser.
+- FP conv 2D (simple & DW), reshape & skip connection, and TinyViT demo tests to the non-tiled Siracusa CI pipeline.
+- FP bindings and mappings for PULP slice, DW conv 2D, and reduce mean operations.
+- FP PULP DW conv lowering optimization pass, analogous to the existing pass for the integer version.
+- RemoveEmptyConvBiasPass to the PULP optimizer.
 - Add manual type inference feature (CLI: `--input-type-map`/`--input-offset-map`) to resolve ambiguities when test inputs are not representative enough
 - Added a `testTypeInferenceDifferentTypes` test case to validate type inference for different input types
 - Added `_mangleNodeNames` function to avoid duplicate node mappings
@@ -48,6 +56,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
 - Memory/I/O summaries and input/output logging in deployers

 ### Changed
+- Reduced the size of the reshape & skip connection test for memory compatibility with non-tiled Siracusa.
 - Replaced platform-specific tags (`*-amd64`, `*-arm64`) with direct digest references in `Noelware/docker-manifest-action`.
 - mchan HAL is now reduced to bare-bones
 - refactor of the IntrospectiveCodeTransformation to work on the Mako template
@@ -75,6 +84,10 @@ This file contains the changelog for the Deeploy project. The changelog is divid
 - Deployer workflow now uses `prepare(...)` instead of `generateFunction(...)`.

 ### Fixed
+- Fixed a bug in `alias_of` node parameter handling, which manages buffer lifetimes in skip-connection situations.
+- Fixed a bug with non-batched operands in the PULPOpen FP GEMM and matmul templates.
+- Added an underscore prefix to closure names to avoid naming issues when they start with unsupported first characters (like digits).
+- Fixed data types in the PULPOpen FP add and mul templates.
 - Prevent node duplication for graphs generated via GraphSurgeon
 - Resolved issue with missing `id` in the `Build Cache for Docker` step, used in the `Inject build-cache` step.
 - Fix license CI check and prevent potential issues with `jq` installation
diff --git a/Deeploy/CommonExtensions/CodeTransformationPasses/Closure.py b/Deeploy/CommonExtensions/CodeTransformationPasses/Closure.py
index c5f9c883a..59a649316 100644
--- a/Deeploy/CommonExtensions/CodeTransformationPasses/Closure.py
+++ b/Deeploy/CommonExtensions/CodeTransformationPasses/Closure.py
@@ -155,7 +155,8 @@ def apply(self,
               executionBlock: ExecutionBlock,
               name: str,
               verbose: CodeGenVerbosity = _NoVerbosity) -> Tuple[NetworkContext, ExecutionBlock]:
-        self.closureName = name + self.closureSuffix
+        # Add a leading underscore to avoid naming issues when the name begins with problematic characters (like digits)
+        self.closureName = "_" + name + self.closureSuffix
         self.functionCall = executionBlock.generate(ctxt)
         self._generateClosureStruct(ctxt, executionBlock)
         ctxt = self._generateClosureCtxt(ctxt, name)
diff --git a/Deeploy/CommonExtensions/NetworkDeployers/SignPropDeployer.py b/Deeploy/CommonExtensions/NetworkDeployers/SignPropDeployer.py
index 7a9fbea1a..828dd3b17 100644
--- a/Deeploy/CommonExtensions/NetworkDeployers/SignPropDeployer.py
+++ b/Deeploy/CommonExtensions/NetworkDeployers/SignPropDeployer.py
@@ -22,7 +22,8 @@ def __init__(self,
                  name: str = 'DeeployNetwork',
                  default_channels_first: bool = True,
                  deeployStateDir: str = "DeeployState",
-                 inputOffsets: Dict[str, int] = {}):
+                 inputOffsets: Dict[str, int] = {},
+                 n_cores: int = 8):
         super().__init__(graph, deploymentPlatform, inputTypes, loweringOptimizer, scheduler, name,
                          default_channels_first, deeployStateDir)
@@ -31,6 +32,7 @@ def __init__(self,
                 inputOffsets[key] = 0

         self.inputOffsets = inputOffsets
+        self.n_cores = n_cores

     def _createIOBindings(self, ctxt, graph):
         ctxt = super()._createIOBindings(ctxt, graph)
diff --git a/Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py b/Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py
index 7ef9e96ef..89eb3ea96 100644
--- a/Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py
+++ b/Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py
@@ -247,7 +247,7 @@ def _NCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_channels_f
     if node_op in ["RequantizedConv", "Conv"]:

         # Non DW-Type:
-        if opNode.attrs['group'] == 1:
+        if opNode.attrs.get('group', 1) == 1:
             weightNode = opNode.inputs[1]
             weightTransposeNode, weightTransposeOutput = _appendTransposeNode(weightNode, name + "TransposeWeight",
                                                                               inPermute)
@@ -341,7 +341,7 @@ def _PULPDWNCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_chan
     opNode = matched_nodes[0]
     node_op = opNode.op

-    if opNode.attrs['group'] == 1:
+    if opNode.attrs.get('group', 1) == 1:
         return graph

     if (("channels_first" in opNode.attrs and opNode.attrs["channels_first"] != default_channels_first)
@@ -362,30 +362,67 @@ def _PULPDWNCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_chan
     graph.nodes.append(outputTransposeNode)

     if 
node_op == "RequantizedConv": - weightNode = opNode.inputs[1] weightTransposeNode, weightTransposeOutput = _appendTransposeNode(weightNode, name + "TransposeWeight", inPermute) opNode.inputs[1] = weightTransposeOutput graph.nodes.append(weightTransposeNode) + else: + inputTransposeNode, inputTransposeOutput = _appendTransposeNode(inputNode, name + "_TransposeIn", inPermute) + opNode.inputs[0] = inputTransposeOutput + graph.nodes.append(inputTransposeNode) opNode.attrs["channels_first"] = default_channels_first return graph +# Requantized DW Conv @contextagnostic class PULPDWConvPass(ReplaceSequentialPatternPass): def __init__(self, default_channels_first: bool = True): + # Define pattern graph graph = gs.Graph() + _input = gs.Variable(name = 'input_1') output = graph.layer(inputs = [_input], outputs = ['convOut'], op = 'RequantizedConv', name = 'requantizedConv') + graph.outputs.append(output) graph.inputs.append(_input) - name = "_NCHW_TO_NHWC_CONV_PASS" - super().__init__(graph, partial(_PULPDWNCHWtoNHWC_fun, default_channels_first = default_channels_first), name) + # Define name + name = "_NCHW_TO_NHWC_DW_CONV_PASS" + + # Initialize Pass + super().__init__(pattern = graph, + replacement_fn = partial(_PULPDWNCHWtoNHWC_fun, + default_channels_first = default_channels_first), + name = name) + + +# Float DW Conv +@contextagnostic +class PULPFPDWConvPass(ReplaceSequentialPatternPass): + + def __init__(self, default_channels_first: bool = True): + # Define pattern graph + graph = gs.Graph() + + _input = gs.Variable(name = 'input_1') + output = graph.layer(inputs = [_input], outputs = ['convOut'], op = 'Conv', name = 'conv') + + graph.outputs.append(output) + graph.inputs.append(_input) + + # Define name + name = "_NCHW_TO_NHWC_FP_DW_CONV_PASS" + + # Initialize Pass + super().__init__(pattern = graph, + replacement_fn = partial(_PULPDWNCHWtoNHWC_fun, + default_channels_first = default_channels_first), + name = name) def _PULPDenseNCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_channels_first: bool = True): @@ -465,6 +502,7 @@ def __init__(self, default_channels_first: bool = True): NCHWtoNHWCPadPass(default_channels_first), NCHWtoNHWCMaxPoolPass(default_channels_first), PULPDWConvPass(default_channels_first), + PULPFPDWConvPass(default_channels_first), PULPNCHWtoNHWCDenseConvPass(default_channels_first), PULPNCHWtoNHWCDenseRequantizedConvPass(default_channels_first), ] diff --git a/Deeploy/DeeployTypes.py b/Deeploy/DeeployTypes.py index e6ca25c9b..ff9d34e4a 100644 --- a/Deeploy/DeeployTypes.py +++ b/Deeploy/DeeployTypes.py @@ -257,7 +257,7 @@ def __init__(self, name: str = '', shape = [1], alias_of: Optional[List[str]] = self.is_input: bool = False self.is_output: bool = False - self.alias_of: List[str] = alias_of if alias_of is not None else [] + self.alias_of: List[str] = list(alias_of) if alias_of is not None else [] def _bufferRepresentation(self) -> Dict: return {"type": self._instance, "name": self.name, "size": int(np.prod(self.shape))} @@ -322,7 +322,11 @@ def __getstate__(self): @classmethod def fromNode(cls, node: gs.Node): - return (cls(name = node.name, shape = node.shape if not isinstance(node, gs.Constant) else node.values.shape)) + return (cls( + name = node.name, + shape = node.shape if not isinstance(node, gs.Constant) else node.values.shape, + alias_of = [], + )) def add_aliases(self, aliases_to_add: List[str]): """ @@ -355,7 +359,7 @@ def get_aliases_of(self): """ if hasattr(self, "alias_of"): - return self.alias_of + return list(self.alias_of) else: return 
list() @@ -399,7 +403,7 @@ class TransientBuffer(VariableBuffer): def __init__(self, name: str = '', size = 0): self.name = name - self.size = size #: int: Total BYTE size of this TransientBuffer + self.size = size # int: Total BYTE size # Do not override - Should be written in the parsing passes self._users = [] @@ -446,7 +450,9 @@ class ConstantBuffer(VariableBuffer): """ def __init__(self, name: str = '', shape = [1], values = [0]): + # Pass a copy of alias_of to avoid shared references super().__init__(name, shape) + values = np.asarray(values) # intArray = values.astype(int) # assert (np.abs(values - intArray)).max() < 0.001, "Constant value {name} is NOT an integer!" @@ -481,7 +487,11 @@ def _bufferRepresentation(self) -> Dict: @classmethod def fromVariableBuffer(cls, buffer: VariableBuffer, values): - ret = cls(name = buffer.name, shape = buffer.shape, values = values) + ret = cls( + name = buffer.name, + shape = buffer.shape, + values = values, + ) return ret @@ -572,7 +582,8 @@ def __init__(self, transientBuffer: Type[TransientBuffer], globalObjects = {}, localObjects = {}, - name: str = 'DeeployNetwork'): + name: str = 'DeeployNetwork', + n_cores: int = 8): self.globalObjects = OrderedDict() self.localObjects = OrderedDict() self.VariableBuffer = variableBuffer @@ -580,6 +591,7 @@ def __init__(self, self.StructBuffer = structBuffer self.TransientBuffer = transientBuffer self.name = name + self.n_cores = n_cores self._maxDynamicSize = {} #: int: Maximum dynamic memory size occupied by live buffers at any point in time self._dynamicSize = {} #: int: Current dynamic memory size occupied by live buffers @@ -874,7 +886,7 @@ def is_buffer(self, value: Any) -> bool: obj = self.lookup(value) return isinstance(obj, VariableBuffer) - def hoistTransientBuffer(self, name: str, size: int) -> str: + def hoistTransientBuffer(self, name: str, size: Union[int, str]) -> str: """Registers a new TransientBuffer in the local context Parameters @@ -1186,7 +1198,11 @@ def parseOutputs(cls, ctxt: NetworkContext, node: gs.Node) -> NetworkContext: for node, name in zip(outputNodes, outputNames): if not ctxt.is_global(name): - nb = ctxt.VariableBuffer(name = name, shape = node.shape) + nb = ctxt.VariableBuffer( + name = name, + shape = node.shape, + alias_of = [], + ) ctxt.add(nb, 'local') else: nb = ctxt.lookup(name) @@ -2487,7 +2503,8 @@ def __init__(self, inputTypes: Dict[str, Type[Pointer]], scheduler: Callable[[gs.Graph], Schedule] = lambda graph: list(graph.nodes), name: str = 'DeeployNetwork', - deeployStateDir: str = "DeeployState"): + deeployStateDir: str = "DeeployState", + n_cores: int = 8): """Initializes a new NetworkContainer and its NetworkContext Parameters @@ -2505,6 +2522,8 @@ def __init__(self, Prefix to use in deployment to uniquify tensor names deeployStateDir : str Path to a directory to dump intermediate outputs + n_cores : int + The number of cores on which the network will be run """ @@ -2523,7 +2542,8 @@ def __init__(self, self.ctxt = NetworkContext(variableBuffer = self.Platform.VariableBuffer, constantBuffer = self.Platform.ConstantBuffer, structBuffer = self.Platform.StructBuffer, - transientBuffer = self.Platform.TransientBuffer) + transientBuffer = self.Platform.TransientBuffer, + n_cores = n_cores) self.deeployStateDir = deeployStateDir @@ -2683,10 +2703,13 @@ def parse(self, default_channels_first: bool = True) -> bool: """ - self.ctxt = NetworkContext(variableBuffer = self.Platform.VariableBuffer, - constantBuffer = self.Platform.ConstantBuffer, - structBuffer = 
self.Platform.StructBuffer, - transientBuffer = self.Platform.TransientBuffer) + self.ctxt = NetworkContext( + variableBuffer = self.Platform.VariableBuffer, + constantBuffer = self.Platform.ConstantBuffer, + structBuffer = self.Platform.StructBuffer, + transientBuffer = self.Platform.TransientBuffer, + n_cores = self.ctxt.n_cores, + ) log.debug(" - Create IO Bindings") self.ctxt = self._createIOBindings(self.ctxt, self.graph) @@ -3232,15 +3255,18 @@ class NetworkDeployer(NetworkContainer): """Deeploy abstraction to contain an entire network and all necessary information to deploy it """ - def __init__(self, - graph: gs.Graph, - deploymentPlatform: DeploymentPlatform, - inputTypes: Dict[str, Type[Pointer]], - loweringOptimizer: TopologyOptimizer, - scheduler: Callable[[gs.Graph], Schedule] = lambda graph: list(graph.nodes), - name: str = 'DeeployNetwork', - default_channels_first: bool = True, - deeployStateDir: str = "DeeployState"): + def __init__( + self, + graph: gs.Graph, + deploymentPlatform: DeploymentPlatform, + inputTypes: Dict[str, Type[Pointer]], + loweringOptimizer: TopologyOptimizer, + scheduler: Callable[[gs.Graph], Schedule] = lambda graph: list(graph.nodes), + name: str = 'DeeployNetwork', + default_channels_first: bool = True, + deeployStateDir: str = "DeeployState", + n_cores: int = 8, + ): """Initialize a new NetworkDeployer Parameters @@ -3269,12 +3295,21 @@ def __init__(self, """ - super().__init__(graph, deploymentPlatform, inputTypes, scheduler, name, deeployStateDir = deeployStateDir) + super().__init__( + graph = graph, + platform = deploymentPlatform, + inputTypes = inputTypes, + scheduler = scheduler, + name = name, + deeployStateDir = deeployStateDir, + n_cores = n_cores, + ) self.loweringOptimizer = loweringOptimizer self.default_channels_first = default_channels_first self.prepared = False + self.n_cores = n_cores def __repr__(self): return super().__repr__( diff --git a/Deeploy/Targets/Generic/Parsers.py b/Deeploy/Targets/Generic/Parsers.py index 3c3a3472c..9c1c4aad1 100644 --- a/Deeploy/Targets/Generic/Parsers.py +++ b/Deeploy/Targets/Generic/Parsers.py @@ -1044,9 +1044,16 @@ def parseNodeCtxt(self, new_output_node_aliases = input_node.get_aliases_of() new_output_node_aliases.append(input_node.name) - # Add new aliases + # Add new aliases to output node output_node.add_aliases(aliases_to_add = new_output_node_aliases) + # Add output node as alias to its aliases (alias relationship is symmetric) + for alias in output_node.get_aliases_of(): + alias_node = ctxt.lookup(alias) + alias_node.add_aliases(aliases_to_add = [ + output_node.name, + ]) + # Compute data size self.operatorRepresentation['size'] = np.prod(ctxt.lookup(node.inputs[0].name).shape) diff --git a/Deeploy/Targets/PULPOpen/Bindings.py b/Deeploy/Targets/PULPOpen/Bindings.py index 9ff940b2f..04ddcb7a5 100644 --- a/Deeploy/Targets/PULPOpen/Bindings.py +++ b/Deeploy/Targets/PULPOpen/Bindings.py @@ -9,13 +9,13 @@ from Deeploy.CommonExtensions.CodeTransformationPasses.Closure import ClosureGeneration, MemoryAwareClosureGeneration from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration, \ MemoryManagementGeneration, MemoryPassthroughGeneration -from Deeploy.CommonExtensions.DataTypes import IntegerDataTypes, SignedIntegerDataTypes, float32_t, int8_t, int32_t, \ - uint8_t +from Deeploy.CommonExtensions.DataTypes import FloatDataTypes, IntegerDataTypes, SignedIntegerDataTypes, float32_t, \ + int8_t, int32_t, uint8_t from Deeploy.DeeployTypes import 
CodeTransformation, NodeBinding, NodeTemplate from Deeploy.FutureExtension.Bindings.AutoFutureBinding import AutoFutureBinding from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration -from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, DequantTemplate, FloatReduceSumTemplate, \ - GatherTemplate, QuantTemplate, RQSiGELUTemplate, iHardswishTemplate +from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, DequantTemplate, FloatReduceMeanTemplate, \ + FloatReduceSumTemplate, GatherTemplate, QuantTemplate, RQSiGELUTemplate, SliceTemplate, iHardswishTemplate from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DequantChecker, \ GatherChecker, GELUChecker, GEMMChecker, HardswishChecker, LayerNormChecker, MatMulChecker, MulChecker, \ QuantChecker, ReduceMeanChecker, ReluChecker, ReshapeChecker, RQAddChecker, RQHardswishChecker, SGDChecker, \ @@ -27,11 +27,11 @@ from Deeploy.Targets.PULPOpen.DataTypes import PULPDMAFuture from Deeploy.Targets.PULPOpen.DMA.L3Dma import l3DmaHack from Deeploy.Targets.PULPOpen.DMA.MchanDma import MchanDma -from Deeploy.Targets.PULPOpen.Templates import ConvTemplate, FloatAddTemplate, FloatConvTemplate, FloatGELUTemplate, \ - FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, \ - FloatReluTemplate, FloatSoftmaxTemplate, GEMMTemplate, MatrixVectorTemplate, MaxPool2DTemplate, MulTemplate, \ - ReduceMeanTemplate, RequantShiftTemplate, ReshapeTemplate, RQAddTemplate, RQSiHardswishTemplate, SGDTemplate, \ - SliceTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \ +from Deeploy.Targets.PULPOpen.Templates import ConvTemplate, DMASliceTemplate, FloatAddTemplate, FloatConvTemplate, \ + FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, \ + FloatMulTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GEMMTemplate, MatrixVectorTemplate, MaxPool2DTemplate, \ + MulTemplate, ReduceMeanTemplate, RequantShiftTemplate, ReshapeTemplate, RQAddTemplate, RQSiHardswishTemplate, \ + SGDTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \ iRMSNormTemplate, iSoftmaxTemplate from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \ PULPRequantShiftChecker @@ -148,10 +148,21 @@ PointerClass(uint8_t), PointerClass(uint8_t), PointerClass(uint8_t) - ], [PULPDMAFuture(underlyingType = type)]), SliceTemplate.referenceTemplate, MemoryAwareForkTransformer) + ], [PULPDMAFuture(underlyingType = type)]), DMASliceTemplate.referenceTemplate, MemoryAwareForkTransformer) for type in IntegerDataTypes ] +PULPSliceBindings = [ + NodeBinding( + SliceChecker([ + PointerClass(type), + PointerClass(uint8_t), + PointerClass(uint8_t), + PointerClass(uint8_t), + PointerClass(uint8_t) + ], [PointerClass(type)]), SliceTemplate.referenceTemplate, ForkTransformer) for type in FloatDataTypes +] + PULPReshapeBindings = [ NodeBinding(ReshapeChecker([PointerClass(type), PointerClass(int32_t)], [PointerClass(type)]), ReshapeTemplate.referenceTemplate, SkipTransformer) for type in IntegerDataTypes @@ -225,6 +236,14 @@ ForkTransformer) ] +PULPFloatDWConv2DBindings = [ + NodeBinding( + ConvChecker( + [PointerClass(float_type), PointerClass(float_type), + PointerClass(float_type)], [PointerClass(float_type)]), 
FloatConvTemplate.referenceDW2DIm2ColTemplate, + ForkTransformer) for float_type in FloatDataTypes +] + PULPRQSMatrixVecBindings = [ NodeBinding( PULPLinearChecker([PointerClass(type1), @@ -276,6 +295,11 @@ PULPReduceMeanBindings = [ NodeBinding(ReduceMeanChecker([PointerClass(type)], [PointerClass(type)]), ReduceMeanTemplate.referenceTemplate, ClusterTransformer) for type in IntegerDataTypes +] + [ + NodeBinding(ReduceMeanChecker([PointerClass(float_type), PointerClass(integer_type)], [PointerClass(float_type)]), + FloatReduceMeanTemplate.referenceTemplate, ClusterTransformer) + for integer_type in SignedIntegerDataTypes + for float_type in FloatDataTypes ] PULPReduceSumBindings = [ diff --git a/Deeploy/Targets/PULPOpen/Deployer.py b/Deeploy/Targets/PULPOpen/Deployer.py index 86bf02e57..c19a65812 100644 --- a/Deeploy/Targets/PULPOpen/Deployer.py +++ b/Deeploy/Targets/PULPOpen/Deployer.py @@ -37,16 +37,20 @@ def __init__(self, name: str = 'DeeployNetwork', default_channels_first = False, deeployStateDir: str = "DeeployStateDir", - inputOffsets = {}): - super().__init__(graph, - deploymentPlatform, - inputTypes, - loweringOptimizer, - scheduler, - name, - default_channels_first = default_channels_first, - deeployStateDir = deeployStateDir, - inputOffsets = inputOffsets) + inputOffsets = {}, + n_cores: int = 8): + super().__init__( + graph = graph, + deploymentPlatform = deploymentPlatform, + inputTypes = inputTypes, + loweringOptimizer = loweringOptimizer, + scheduler = scheduler, + name = name, + default_channels_first = default_channels_first, + deeployStateDir = deeployStateDir, + inputOffsets = inputOffsets, + n_cores = n_cores, + ) self.loweringOptimizer.passes += [ TransposeMatmulInputsPass(), diff --git a/Deeploy/Targets/PULPOpen/Parsers.py b/Deeploy/Targets/PULPOpen/Parsers.py index e94af6e42..cc0ff8b12 100644 --- a/Deeploy/Targets/PULPOpen/Parsers.py +++ b/Deeploy/Targets/PULPOpen/Parsers.py @@ -77,7 +77,7 @@ def parseNode(self, node: gs.Node) -> (bool): self.operatorRepresentation['pads'][0] == self.operatorRepresentation['pads'][2], self.operatorRepresentation['pads'][1] == self.operatorRepresentation['pads'][3], self.operatorRepresentation['pads'][0] == self.operatorRepresentation['pads'][1], - len(node.inputs) == 2 + len(node.inputs) in [2, 3], ]) self.operatorRepresentation['dim_kernel_x'] = int(self.operatorRepresentation['kernel_shape'][0]) @@ -102,11 +102,93 @@ def parseNodeCtxt(self, newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) if ret: + inputs = ['data_in', 'weight'] + + # Handle bias, if present + if len(node.inputs) == 2: + self.operatorRepresentation["has_bias"] = "false" + self.operatorRepresentation["bias"] = "NULL" + else: + inputs.append("bias") + self.operatorRepresentation["has_bias"] = "true" + + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name + return newCtxt, True return ctxt, False +class PULPFPDWConv2DParser(Conv2DParser): + + def __init__(self, noBiasHoisting = True): + super().__init__(noBiasHoisting) + + def parseNode(self, node: gs.Node) -> (bool): + # Parse root conv 2D information + wellFormed = super().parseNode(node) + + if wellFormed: + # Check if the node is a depthwise convolution + ret = all([ + # Make sure padding is square + self.operatorRepresentation['pads'][0] == self.operatorRepresentation['pads'][2], + self.operatorRepresentation['pads'][1] == self.operatorRepresentation['pads'][3], + self.operatorRepresentation['pads'][0] == 
self.operatorRepresentation['pads'][1], + + # Check number of inputs + len(node.inputs) in [2, 3], + ]) + + # Extract additional attributes + self.operatorRepresentation['dim_kernel_x'] = int(self.operatorRepresentation['kernel_shape'][0]) + self.operatorRepresentation['dim_kernel_y'] = int(self.operatorRepresentation['kernel_shape'][1]) + + self.operatorRepresentation['dilation_x'] = int(self.operatorRepresentation['dilations'][0]) + self.operatorRepresentation['dilation_y'] = int(self.operatorRepresentation['dilations'][1]) + + self.operatorRepresentation['padding_y_top'] = int(self.operatorRepresentation['pads'][0]) + self.operatorRepresentation['padding_x_left'] = int(self.operatorRepresentation['pads'][1]) + self.operatorRepresentation['padding_y_bottom'] = int(self.operatorRepresentation['pads'][2]) + self.operatorRepresentation['padding_x_right'] = int(self.operatorRepresentation['pads'][3]) + + self.operatorRepresentation['stride_x'] = int(self.operatorRepresentation['strides'][0]) + self.operatorRepresentation['stride_y'] = int(self.operatorRepresentation['strides'][1]) + + return ret + return False + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + # Parse node context for 2D conv + newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) + + if ret: + # Define input names + inputs = ['data_in', 'weight'] + + # Handle bias, if present + if len(node.inputs) == 2: + self.operatorRepresentation["has_bias"] = "false" + self.operatorRepresentation["bias"] = "NULL" + else: + inputs.append("bias") + self.operatorRepresentation["has_bias"] = "true" + + # Map input nodes to operator representation + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name + + # Check if DW + if self.operatorRepresentation['group'] == self.operatorRepresentation['ch_im_in']: + return newCtxt, True + + return ctxt, False + + class PULPDWConv1DParser(RQSConv1DParser): def __init__(self, noBiasHoisting = True): diff --git a/Deeploy/Targets/PULPOpen/Platform.py b/Deeploy/Targets/PULPOpen/Platform.py index 99c1c9335..9c22afec7 100644 --- a/Deeploy/Targets/PULPOpen/Platform.py +++ b/Deeploy/Targets/PULPOpen/Platform.py @@ -5,6 +5,8 @@ import numpy as np import onnx_graphsurgeon as gs +from Deeploy.CommonExtensions.OptimizationPasses.TopologyOptimizationPasses.LoweringOptimizationPasses import \ + RemoveEmptyConvBiasPass from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NetworkContext, NodeMapper, \ NodeTemplate, StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer from Deeploy.MemoryLevelExtension.MemoryLevels import MemoryHierarchy, MemoryLevel @@ -27,10 +29,11 @@ MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, QuantPatternPass, RQSSplitPass, \ SkipEmptyConcatPass, SkipUnityRequantPass, iGELURequantMergePass, iHardswishRequantMergePass from Deeploy.Targets.PULPOpen.Bindings import BasicDequantBindings, BasicQuantBindings, PULPConv1DBinding, \ - PULPDMASliceBindings, PULPDWConv1DBinding, PULPReduceMeanBindings + PULPDMASliceBindings, PULPDWConv1DBinding, PULPFloatDWConv2DBindings, PULPReduceMeanBindings, PULPSliceBindings from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \ - PULPDWConv2DParser, PULPFPConv2DParser, PULPGEMMParser, PULPMatrixVecParser, 
PULPTallGEMMParser + PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, \ + PULPTallGEMMParser from Deeploy.Targets.PULPOpen.Templates import AllocateTemplate, FreeTemplate from Deeploy.Targets.PULPOpen.Tiler import PULPAddTilingReadyBindings, PULPConcatTilingReadyBindings, \ PULPConv2DTilingReadyBindings, PULPFlattenTilingReadyBindings, PULPFPGELUTilingReadyBindings, \ @@ -71,6 +74,7 @@ DWConv1DMapper = NodeMapper(PULPDWConv1DParser(), [PULPDWConv1DBinding]) FPConv2DMapper = NodeMapper(PULPFPConv2DParser(), PULPConv2DTilingReadyBindings) Conv2DMapper = NodeMapper(PULPConv2DParser(), PULPRQSConv2DTilingReadyBindings) +FPDWConv2DMapper = NodeMapper(PULPFPDWConv2DParser(), PULPFloatDWConv2DBindings) DWConv2DMapper = NodeMapper(PULPDWConv2DParser(), PULPRQSDWConv2DTilingReadyBindings) GEMMMapper = NodeMapper(PULPGEMMParser(), PULPRQSGEMMTilingReadyBindings) FloatGEMMMapper = NodeMapper(GEMMParser(), PULPFPGEMMTilingReadyBindings) @@ -85,7 +89,9 @@ ConcatMapper = NodeMapper(ConcatParser(), PULPConcatTilingReadyBindings) -SliceMapper = NodeMapper(SliceParser(), PULPDMASliceBindings) +DMASliceMapper = NodeMapper(SliceParser(), PULPDMASliceBindings) + +SliceMapper = NodeMapper(SliceParser(), PULPSliceBindings) iRMSNormMapper = NodeMapper(iRMSNormParser(), PULPiRMSNormTilingReadyBindings) @@ -99,7 +105,7 @@ DequantMapper = NodeMapper(DequantParser(), BasicDequantBindings) GEMMDequantMapper = NodeMapper(PULPGEMMParser(), BasicGEMMBindings) PULPMapping = { - 'Conv': ConvLayer([FPConv2DMapper]), + 'Conv': ConvLayer([FPConv2DMapper, FPDWConv2DMapper]), 'RequantizedConv': PULPRQSConvLayer([Conv2DMapper, DWConv2DMapper, Conv1DMapper, DWConv1DMapper]), 'RequantizedGemm': PULPRQSGEMMLayer([MatrixVecMapper, TallGEMMMapper, GEMMMapper]), 'Gemm': GEMMLayer([FloatGEMMMapper, GEMMDequantMapper]), @@ -125,7 +131,7 @@ 'Squeeze': ReshapeLayer([UnsqueezeMapper]), 'Transpose': TransposeLayer([TransposeMapper]), 'Unsqueeze': ReshapeLayer([UnsqueezeMapper]), - 'Slice': SliceLayer([SliceMapper]), + 'Slice': SliceLayer([SliceMapper, DMASliceMapper]), 'RequantizedAdd': AddLayer([RQAddMapper]), 'Concat': ConcatLayer([ConcatMapper]), 'iRMSNorm': iRMSNormLayer([iRMSNormMapper]), @@ -225,9 +231,9 @@ class PULPStructBuffer(StructBuffer): MergeConstAddAndRequantPass(), PULPGEMMRequantMergePass(), PULPMatMulRequantMergePass(), - PULPAddRequantMergePass() -], - name = "PULPOptimizer") + PULPAddRequantMergePass(), + RemoveEmptyConvBiasPass(), +]) # SCHEREMO: stdint is included before pulp_nn_kernels.h because it is supposed to be included in there, but isn't... 
_includeList = [ diff --git a/Deeploy/Targets/PULPOpen/Templates/SliceTemplate.py b/Deeploy/Targets/PULPOpen/Templates/DMASliceTemplate.py similarity index 100% rename from Deeploy/Targets/PULPOpen/Templates/SliceTemplate.py rename to Deeploy/Targets/PULPOpen/Templates/DMASliceTemplate.py diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatAddTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatAddTemplate.py index 7f1c2e21c..200ad1b9e 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatAddTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatAddTemplate.py @@ -6,14 +6,14 @@ referenceTemplate = NodeTemplate(""" // Add Parallel with 1x6 unrolling (Name: ${nodeName}, Op: ${nodeOp}) -int8_t ${nodeName}_core_id = pi_core_id(); -int8_t ${nodeName}_log2Core = log2(NUM_CORES); -int16_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0); -int16_t ${nodeName}_chunk_start = MIN(${nodeName}_chunk*${nodeName}_core_id, ${size}); -int16_t ${nodeName}_chunk_stop = MIN(${nodeName}_chunk_start + ${nodeName}_chunk, ${size}); +uint8_t ${nodeName}_core_id = (uint8_t) pi_core_id(); +uint8_t ${nodeName}_log2Core = (uint8_t) log2(NUM_CORES); +uint32_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0); +uint32_t ${nodeName}_chunk_start = (uint32_t) MIN(${nodeName}_chunk*${nodeName}_core_id, (uint32_t) ${size}); +uint32_t ${nodeName}_chunk_stop = (uint32_t) MIN(${nodeName}_chunk_start + ${nodeName}_chunk, (uint32_t) ${size}); uint32_t i = ${nodeName}_chunk_start; -for (; i+5 < ${nodeName}_chunk_stop; i+=6) { +for (; i + 5 < ${nodeName}_chunk_stop; i += 6) { ${data_out}[i] = ${data_in_1}[i] + ${data_in_2}[i]; ${data_out}[i+1] = ${data_in_1}[i+1] + ${data_in_2}[i+1]; ${data_out}[i+2] = ${data_in_1}[i+2] + ${data_in_2}[i+2]; diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatConvTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatConvTemplate.py index 29a216d72..c15cfba0c 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatConvTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatConvTemplate.py @@ -18,8 +18,8 @@ def __init__(self, templateStr): def computeTransientBuffersSize( ctxt: NetworkContext, operatorRepresentation: OperatorRepresentation) -> List[Tuple[str, Union[int, IntVar]]]: - im2col_dim = 4 * 8 * (operatorRepresentation['ch_im_in'] * operatorRepresentation['dim_kernel_x'] * - operatorRepresentation['dim_kernel_y']) + im2col_dim = (operatorRepresentation["weight_type"].typeWidth // 8) * ctxt.n_cores * operatorRepresentation[ + 'ch_im_in'] * operatorRepresentation['dim_kernel_x'] * operatorRepresentation['dim_kernel_y'] im2col_name = operatorRepresentation['nodeName'] + "_buffer" return [(im2col_name, im2col_dim)] @@ -34,6 +34,37 @@ def hoistTransientBuffers(self, ctxt: NetworkContext, return ctxt, operatorRepresentation, [im2col_name] +class PULP2DFloatDWConvIm2ColTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + @staticmethod + def computeTransientBuffersSize(ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> List[Tuple[str, str]]: + + # Memory allocation for the im2col buffer can be dynamic, based on the number of cores. + # WARNING: This works because value is only used as string, in the allocate template. 
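+        # Sizing sketch: each core gets a private im2col window holding a single
+        # kernel patch (dim_kernel_x * dim_kernel_y elements of one channel, since
+        # the convolution is depthwise), so the buffer holds
+        # n_cores * dim_kernel_x * dim_kernel_y elements of the weight data type.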
+        im2col_dim = (operatorRepresentation["weight_type"].typeWidth // 8
+                     ) * ctxt.n_cores * operatorRepresentation['dim_kernel_x'] * operatorRepresentation['dim_kernel_y']
+        im2col_name = operatorRepresentation['nodeName'] + "_buffer"
+        return [(im2col_name, im2col_dim)]
+
+    def hoistTransientBuffers(self, ctxt: NetworkContext,
+                              operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
+        im2col_name, im2col_dim = PULP2DFloatDWConvIm2ColTemplate.computeTransientBuffersSize(
+            ctxt, operatorRepresentation)[0]
+        ctxt.hoistTransientBuffer(im2col_name, im2col_dim)
+
+        # Manually set the type of the im2col buffer to match the input type, since it defaults to void for transient buffers
+        ctxt.lookup(im2col_name)._type.referencedType = ctxt.lookup(
+            operatorRepresentation['data_in'])._type.referencedType
+
+        operatorRepresentation['ctxtBuffer'] = im2col_name
+        operatorRepresentation['ctxtBufferSize'] = im2col_dim
+        return ctxt, operatorRepresentation, [im2col_name]
+
+
 reference2DTemplate = NodeTemplate("""
 // 2D FP Conv HWC with ChannelOut parallelism (Name: ${nodeName}, Op: ${nodeOp})
@@ -47,6 +78,7 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
     ${weight}, ${ch_im_out},
     ${dim_kernel_y}, ${dim_kernel_x},
     ${stride_y}, ${stride_x},
+    ${bias}, ${has_bias},
     ref_${data_out}_${data_out},
     ${padding_y_top}, ${padding_y_bottom},
     ${padding_x_left}, ${padding_x_right}
 );
@@ -65,16 +97,49 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
 for (uint32_t n=0; n<${batch}; ++n) {
     PULP_Conv2d_Im2Col_fp${data_in_type.referencedType.typeWidth}_fp${weight_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}_HWC(
-        ref_${data_out}_${data_in},
-        ${dim_im_in_y},
-        ${dim_im_in_x},
-        ${ch_im_in},
-        ${weight},
-        ${ch_im_out},
-        ${dim_kernel_y},
-        ${dim_kernel_x},
+        ref_${data_out}_${data_in},
+        ${dim_im_in_x},
+        ${dim_im_in_y},
+        ${ch_im_in},
+        ${weight},
+        ${ch_im_out},
+        ${dim_kernel_x},
+        ${dim_kernel_y},
+        ${stride_x},
+        ${stride_y},
+        ${bias}, ${has_bias},
+        ref_${data_out}_${data_out},
+        ${padding_y_top},
+        ${padding_y_bottom},
+        ${padding_x_left},
+        ${padding_x_right},
+        ${ctxtBuffer}
+    );
+
+    ref_${data_out}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
+    ref_${data_out}_${data_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
+}
+""")
+
+referenceDW2DIm2ColTemplate = PULP2DFloatDWConvIm2ColTemplate("""
+// 2D DW FP Conv HWC with Im2Col and ChannelOut parallelism (Name: ${nodeName}, Op: ${nodeOp})
+
+${data_in_type.typeName} ref_${data_out}_${data_in} = ${data_in};
+${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
+
+for (uint32_t n=0; n<${batch}; ++n) {
+    PULP_DW_Conv2d_Im2Col_fp${data_in_type.referencedType.typeWidth}_fp${weight_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}_HWC(
+        ref_${data_out}_${data_in},
+        ${dim_im_in_x},
+        ${dim_im_in_y},
+        ${ch_im_in},
+        ${weight},
+        ${ch_im_out},
+        ${dim_kernel_x},
+        ${dim_kernel_y},
+        ${stride_x},
         ${stride_y},
-        ${stride_x},
+        ${bias}, ${has_bias},
         ref_${data_out}_${data_out},
         ${padding_y_top},
         ${padding_y_bottom},
@@ -86,4 +151,4 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
     ref_${data_out}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
     ref_${data_out}_${data_out} += ${ch_im_out} * ${dim_im_out_x} * ${dim_im_out_y};
 }
-""")
\ No newline at end of file
+""")
diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py
index 
f4c22b2c2..ab39d266c 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py @@ -23,10 +23,19 @@ ${transA}, ${transB} ); - + + % if A_batched: ref_${data_out}_${A} += ${M} * ${N}; + % endif + + % if B_batched: ref_${data_out}_${B} += ${N} * ${O}; + % endif + + % if C_batched: ref_${data_out}_${C} += ${M} * ${O}; + % endif + ref_${data_out}_${data_out} += ${M} * ${O}; } """) \ No newline at end of file diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatMatMulTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatMatMulTemplate.py index 11b7c9aa2..158abe5ab 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatMatMulTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatMatMulTemplate.py @@ -8,8 +8,18 @@ // Matmul with row parallelism (Name: ${nodeName}, Op: ${nodeOp}) for(uint32_t b=0; b<${batch}; b++) { + % if A_batched: ${A_type.typeName} batch_A = ${A} + b * ${M} * ${N}; + % else: + ${A_type.typeName} batch_A = ${A}; + % endif + + % if B_batched: ${B_type.typeName} batch_B = ${B} + b * ${N} * ${O}; + % else: + ${B_type.typeName} batch_B = ${B}; + % endif + ${data_out_type.typeName} batch_out = ${data_out} + b * ${M} * ${O}; PULP_MatMul_fp32_fp32_fp32_unroll1x7( diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatMulTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatMulTemplate.py index 2f202b24d..ced6c3cbc 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatMulTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatMulTemplate.py @@ -7,11 +7,11 @@ referenceTemplate = NodeTemplate(""" // Float Mul with parallelism and 6x unrolling (Name: ${nodeName}, Op: ${nodeOp}) -int8_t ${nodeName}_core_id = pi_core_id(); -int8_t ${nodeName}_log2Core = log2(NUM_CORES); +uint32_t ${nodeName}_core_id = pi_core_id(); +uint32_t ${nodeName}_log2Core = (uint32_t) log2(NUM_CORES); uint32_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1)) != 0); -uint32_t ${nodeName}_start = MIN(${nodeName}_chunk * ${nodeName}_core_id, ${size}); -uint32_t ${nodeName}_end = MIN(${nodeName}_start + ${nodeName}_chunk, ${size}); +uint32_t ${nodeName}_start = MIN(${nodeName}_chunk * ${nodeName}_core_id, (uint32_t) ${size}); +uint32_t ${nodeName}_end = MIN(${nodeName}_start + ${nodeName}_chunk, (uint32_t) ${size}); if (${nodeName}_start < ${nodeName}_end) { float32_t ${nodeName}_scalar = ${B}[0]; diff --git a/Deeploy/TilingExtension/TilerExtension.py b/Deeploy/TilingExtension/TilerExtension.py index bdae0fbdc..e44c7dfff 100644 --- a/Deeploy/TilingExtension/TilerExtension.py +++ b/Deeploy/TilingExtension/TilerExtension.py @@ -985,6 +985,7 @@ def tile(self, tilingSolution: Optional[TilingSolution] = None, memoryMap: Optio self.tiler.annotateMemoryLevel(self.ctxt, tilingSolution, memoryMap) self.ctxt = self.tiler._convertCtxtToStaticSchedule(self.ctxt, memoryMap) + self.ctxt.n_cores = self.n_cores if self.tiler.visualizeMemoryAlloc: log.info(f" > Export Memory Allocation Visualization to {self.deeployStateDir}") diff --git a/DeeployTest/Tests/testFloatReshapeWithSkipConnection/inputs.npz b/DeeployTest/Tests/testFloatReshapeWithSkipConnection/inputs.npz index a98a6c33b..36567a96c 100644 Binary files a/DeeployTest/Tests/testFloatReshapeWithSkipConnection/inputs.npz and b/DeeployTest/Tests/testFloatReshapeWithSkipConnection/inputs.npz differ diff --git a/DeeployTest/Tests/testFloatReshapeWithSkipConnection/network.onnx b/DeeployTest/Tests/testFloatReshapeWithSkipConnection/network.onnx index ae1b3ac93..5eb3ae446 100644 
Binary files a/DeeployTest/Tests/testFloatReshapeWithSkipConnection/network.onnx and b/DeeployTest/Tests/testFloatReshapeWithSkipConnection/network.onnx differ
diff --git a/DeeployTest/Tests/testFloatReshapeWithSkipConnection/outputs.npz b/DeeployTest/Tests/testFloatReshapeWithSkipConnection/outputs.npz
index a5d4b6e97..0e2e55fcf 100644
Binary files a/DeeployTest/Tests/testFloatReshapeWithSkipConnection/outputs.npz and b/DeeployTest/Tests/testFloatReshapeWithSkipConnection/outputs.npz differ
diff --git a/DeeployTest/testMVP.py b/DeeployTest/testMVP.py
index 013f854da..9e2294c52 100644
--- a/DeeployTest/testMVP.py
+++ b/DeeployTest/testMVP.py
@@ -55,8 +55,8 @@ def _filterSchedule(schedule: List[List[gs.Node]], layerBinding: 'OrderedDict[st
 def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTargetMemoryLevel: MemoryLevel,
-                  defaultIoMemoryLevel: MemoryLevel, verbose: CodeGenVerbosity,
-                  args: argparse.Namespace) -> Tuple[NetworkDeployer, bool]:
+                  defaultIoMemoryLevel: MemoryLevel, verbose: CodeGenVerbosity, args: argparse.Namespace,
+                  n_cores: int) -> Tuple[NetworkDeployer, bool]:

     inputTypes = {}
     inputOffsets = {}
@@ -81,12 +81,15 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg
         inputTypes[f"input_{index}"] = _type
         inputOffsets[f"input_{index}"] = offset

-    deployer = mapDeployer(platform,
-                           graph,
-                           inputTypes,
-                           deeployStateDir = _DEEPLOYSTATEDIR,
-                           inputOffsets = inputOffsets,
-                           scheduler = _mockScheduler)
+    deployer = mapDeployer(
+        platform,
+        graph,
+        inputTypes,
+        deeployStateDir = _DEEPLOYSTATEDIR,
+        inputOffsets = inputOffsets,
+        scheduler = _mockScheduler,
+        n_cores = n_cores,
+    )

     # Make the deployer engine-color-aware
     if args.platform == "Siracusa_w_neureka":
@@ -195,6 +198,11 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg
     parser.add_argument('--plotMemAlloc',
                         action = 'store_true',
                         help = 'Turn on plotting of the memory allocation and save it in the deeployState folder\n')
+    parser.add_argument(
+        "-n_cores",
+        type = int,
+        default = 8,
+        help = "Number of cores to target in the tiling. Currently required for im2col buffer sizing. 
Default: 8") parser.set_defaults(shouldFail = False) args = parser.parse_args() @@ -250,7 +258,8 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg defaultTargetMemoryLevel = L1, defaultIoMemoryLevel = memoryHierarchy.memoryLevels[args.defaultMemLevel], verbose = verbosityCfg, - args = args) + args = args, + n_cores = args.n_cores) platform = deployer.Platform diff --git a/DeeployTest/testRunner_tiled_siracusa.py b/DeeployTest/testRunner_tiled_siracusa.py index 7bf08b7b2..20a35dd9b 100644 --- a/DeeployTest/testRunner_tiled_siracusa.py +++ b/DeeployTest/testRunner_tiled_siracusa.py @@ -17,7 +17,13 @@ help = 'Set number of cluster cores') args = parser.parse_args() - testRunner = TestRunner(platform = "Siracusa", simulator = "gvsoc", tiling = True, argument_parser = parser) + testRunner = TestRunner( + platform = "Siracusa", + simulator = "gvsoc", + tiling = True, + argument_parser = parser, + cores = args.cores, + ) testRunner.cmake_args += f" -D NUM_CORES={args.cores}" diff --git a/DeeployTest/testUtils/platformMapping.py b/DeeployTest/testUtils/platformMapping.py index 48c577790..31bf2176c 100644 --- a/DeeployTest/testUtils/platformMapping.py +++ b/DeeployTest/testUtils/platformMapping.py @@ -89,15 +89,18 @@ def setupMemoryPlatform(platform: DeploymentPlatform, memoryHierarchy: MemoryHie return MemoryPlatformWrapper(platform, memoryHierarchy, defaultTargetMemoryLevel) -def mapDeployer(platform: DeploymentPlatform, - graph: gs.Graph, - inputTypes: Dict[str, Type[Pointer]], - loweringOptimizer: Optional[TopologyOptimizer] = None, - scheduler: Optional[Callable] = None, - name: Optional[str] = None, - default_channels_first: Optional[bool] = None, - deeployStateDir: Optional[str] = None, - inputOffsets: Optional[Dict[str, int]] = None) -> NetworkDeployer: +def mapDeployer( + platform: DeploymentPlatform, + graph: gs.Graph, + inputTypes: Dict[str, Type[Pointer]], + loweringOptimizer: Optional[TopologyOptimizer] = None, + scheduler: Optional[Callable] = None, + name: Optional[str] = None, + default_channels_first: Optional[bool] = None, + deeployStateDir: Optional[str] = None, + inputOffsets: Optional[Dict[str, int]] = None, + n_cores: Optional[int] = 8, +) -> NetworkDeployer: if scheduler is None: scheduler = defaultScheduler @@ -208,14 +211,17 @@ def mapDeployer(platform: DeploymentPlatform, if default_channels_first is None: default_channels_first = False - deployer = PULPDeployer(graph, - platform, - inputTypes, - loweringOptimizer, - scheduler, - name = name, - default_channels_first = default_channels_first, - deeployStateDir = deeployStateDir) + deployer = PULPDeployer( + graph, + platform, + inputTypes, + loweringOptimizer, + scheduler, + name = name, + default_channels_first = default_channels_first, + deeployStateDir = deeployStateDir, + n_cores = n_cores, + ) elif isinstance(platform, (SnitchPlatform)): if loweringOptimizer is None: diff --git a/DeeployTest/testUtils/testRunner.py b/DeeployTest/testUtils/testRunner.py index 7d1f7f312..1f186902c 100644 --- a/DeeployTest/testUtils/testRunner.py +++ b/DeeployTest/testUtils/testRunner.py @@ -278,13 +278,16 @@ def cmake_args(self) -> str: class TestRunner(): - def __init__(self, - platform: str, - simulator: Literal['gvsoc', 'banshee', 'qemu', 'vsim', 'vsim.gui', 'host', 'none'], - tiling: bool, - argument_parser: TestRunnerArgumentParser, - gen_args: str = "", - cmake_args: str = ""): + def __init__( + self, + platform: str, + simulator: Literal['gvsoc', 'banshee', 'qemu', 'vsim', 'vsim.gui', 'host', 
'none'], + tiling: bool, + argument_parser: TestRunnerArgumentParser, + gen_args: str = "", + cmake_args: str = "", + cores: int = 8, + ): if simulator not in ['gvsoc', 'banshee', 'qemu', 'vsim', 'vsim.gui', 'host', 'none']: raise ValueError( @@ -304,6 +307,8 @@ def __init__(self, self.cmake_args = cmake_args self.gen_args = gen_args + self.n_cores = cores + self._dir_gen_root = f'TEST_{platform.upper()}' assert self._args.toolchain_install_dir is not None, f"Environment variable LLVM_INSTALL_DIR is not set" self._dir_toolchain = os.path.normpath(self._args.toolchain_install_dir) @@ -342,6 +347,13 @@ def generate_test(self): generation_script = "generateNetwork.py" command = f"python {generation_script} -d {self._dir_gen} -t {self._dir_test} -p {self._platform} {self.gen_args}" + + command = f"python {generation_script} -d {self._dir_gen} -t {self._dir_test} -p {self._platform}" + + if self._tiling is True: + command += f" -n_cores {self.n_cores}" + + command += f" {self.gen_args}" command += self._argument_parser.generate_cmd_args() log.debug(f"[TestRunner] Generation Command: {command}") diff --git a/TargetLibraries/PULPOpen/inc/kernel/Conv.h b/TargetLibraries/PULPOpen/inc/kernel/Conv.h index f5382a339..3ebab54a0 100644 --- a/TargetLibraries/PULPOpen/inc/kernel/Conv.h +++ b/TargetLibraries/PULPOpen/inc/kernel/Conv.h @@ -9,20 +9,30 @@ #include "DeeployPULPMath.h" -void PULP_Conv2d_fp32_fp32_fp32_HWC(const float32_t *__restrict__ pSrcA, - uint32_t H, uint32_t W, uint32_t C, - const float32_t *__restrict__ pSrcB, - uint32_t F_total, uint32_t P, uint32_t Q, - uint32_t SP, uint32_t SQ, - float32_t *__restrict__ pDstC, - uint32_t pad_top, uint32_t pad_bottom, - uint32_t pad_left, uint32_t pad_right); +void PULP_Conv2d_fp32_fp32_fp32_HWC( + const float32_t *__restrict__ pSrcA, uint32_t H, uint32_t W, uint32_t C, + const float32_t *__restrict__ pSrcB, uint32_t F_total, uint32_t P, + uint32_t Q, uint32_t SP, uint32_t SQ, + const float32_t *__restrict__ pSrcBias, const bool has_bias, + float32_t *__restrict__ pDstC, uint32_t pad_top, uint32_t pad_bottom, + uint32_t pad_left, uint32_t pad_right); void PULP_Conv2d_Im2Col_fp32_fp32_fp32_HWC( const float32_t *__restrict__ pSrcA, uint32_t H, uint32_t W, uint32_t C, const float32_t *__restrict__ pSrcB, uint32_t F_total, uint32_t P, - uint32_t Q, uint32_t SP, uint32_t SQ, float32_t *__restrict__ pDstC, - uint32_t pad_top, uint32_t pad_bottom, uint32_t pad_left, - uint32_t pad_right, float32_t *__restrict__ pContextBuffer); + uint32_t Q, uint32_t SP, uint32_t SQ, + const float32_t *__restrict__ pSrcBias, const bool has_bias, + float32_t *__restrict__ pDstC, uint32_t pad_top, uint32_t pad_bottom, + uint32_t pad_left, uint32_t pad_right, + float32_t *__restrict__ pContextBuffer); + +void PULP_DW_Conv2d_Im2Col_fp32_fp32_fp32_HWC( + const float32_t *__restrict__ pSrcA, uint32_t H, uint32_t W, uint32_t C, + const float32_t *__restrict__ pSrcB, uint32_t F_total, uint32_t P, + uint32_t Q, uint32_t SP, uint32_t SQ, + const float32_t *__restrict__ pSrcBias, const bool has_bias, + float32_t *__restrict__ pDstC, uint32_t pad_top, uint32_t pad_bottom, + uint32_t pad_left, uint32_t pad_right, + float32_t *__restrict__ pContextBuffer); #endif // __DEEPLOY_MATH_CONV_KERNEL_HEADER_ \ No newline at end of file diff --git a/TargetLibraries/PULPOpen/src/Convolution_fp32.c b/TargetLibraries/PULPOpen/src/Convolution_fp32.c index c33ac31e8..af2129323 100644 --- a/TargetLibraries/PULPOpen/src/Convolution_fp32.c +++ b/TargetLibraries/PULPOpen/src/Convolution_fp32.c @@ -7,18 
+7,19 @@ #include "DeeployPULPMath.h" #include "pmsis.h" -void PULP_Conv2d_fp32_fp32_fp32_HWC(const float32_t *__restrict__ pSrcA, - uint32_t H, uint32_t W, uint32_t C, - const float32_t *__restrict__ pSrcB, - uint32_t F_total, uint32_t P, uint32_t Q, - uint32_t SP, uint32_t SQ, - float32_t *__restrict__ pDstC, - uint32_t pad_top, uint32_t pad_bottom, - uint32_t pad_left, uint32_t pad_right) { +void PULP_Conv2d_fp32_fp32_fp32_HWC( + const float32_t *__restrict__ pSrcA, uint32_t H, uint32_t W, uint32_t C, + const float32_t *__restrict__ pSrcB, uint32_t F_total, uint32_t P, + uint32_t Q, uint32_t SP, uint32_t SQ, + const float32_t *__restrict__ pSrcBias, const bool has_bias, + float32_t *__restrict__ pDstC, uint32_t pad_top, uint32_t pad_bottom, + uint32_t pad_left, uint32_t pad_right) { + // Compute core int8_t core_id = pi_core_id(); int8_t log2Core = LOG2(NUM_CORES); + // Compute the chunk size for each core uint16_t ch_out_chunk = (F_total >> log2Core) + ((F_total & (NUM_CORES - 1)) != 0); uint16_t ch_out_start = MIN(ch_out_chunk * core_id, F_total); @@ -29,37 +30,72 @@ void PULP_Conv2d_fp32_fp32_fp32_HWC(const float32_t *__restrict__ pSrcA, return; } + // Pointer to the weights for the current core const float32_t *weight_ptr = pSrcB + ch_out_start * C * P * Q; + // Compute the output dimensions uint32_t H_out = (H + pad_top + pad_bottom - P) / SP + 1; uint32_t W_out = (W + pad_left + pad_right - Q) / SQ + 1; - for (uint32_t h = 0; h < H_out; ++h) { - for (uint32_t w = 0; w < W_out; ++w) { - for (uint32_t f = 0; f < ch_out_count; ++f) { - float32_t sum = 0.0f; + // Compute the output + if (has_bias) { + for (uint32_t h = 0; h < H_out; ++h) { + for (uint32_t w = 0; w < W_out; ++w) { + for (uint32_t f = 0; f < ch_out_count; ++f) { + float32_t sum = 0.0f; - for (uint32_t p = 0; p < P; ++p) { - for (uint32_t q = 0; q < Q; ++q) { - for (uint32_t c = 0; c < C; ++c) { - int32_t h_in = h * SP + p - pad_top; - int32_t w_in = w * SQ + q - pad_left; + for (uint32_t p = 0; p < P; ++p) { + for (uint32_t q = 0; q < Q; ++q) { + for (uint32_t c = 0; c < C; ++c) { + int32_t h_in = h * SP + p - pad_top; + int32_t w_in = w * SQ + q - pad_left; - if (h_in < 0 || h_in >= (int32_t)H || w_in < 0 || - w_in >= (int32_t)W) { - continue; - } + if (h_in < 0 || h_in >= (int32_t)H || w_in < 0 || + w_in >= (int32_t)W) { + continue; + } - uint32_t input_idx = (h_in * W + w_in) * C + c; - uint32_t weight_idx = f * (P * Q * C) + p * (Q * C) + q * C + c; + uint32_t input_idx = (h_in * W + w_in) * C + c; + uint32_t weight_idx = f * (P * Q * C) + p * (Q * C) + q * C + c; - sum += pSrcA[input_idx] * weight_ptr[weight_idx]; + sum += pSrcA[input_idx] * weight_ptr[weight_idx]; + } } } + + uint32_t output_idx = (h * W_out + w) * F_total + (ch_out_start + f); + pDstC[output_idx] = sum + pSrcBias[f + ch_out_start]; } + } + } + } else { + for (uint32_t h = 0; h < H_out; ++h) { + for (uint32_t w = 0; w < W_out; ++w) { + for (uint32_t f = 0; f < ch_out_count; ++f) { + float32_t sum = 0.0f; + + for (uint32_t p = 0; p < P; ++p) { + for (uint32_t q = 0; q < Q; ++q) { + for (uint32_t c = 0; c < C; ++c) { + int32_t h_in = h * SP + p - pad_top; + int32_t w_in = w * SQ + q - pad_left; + + if (h_in < 0 || h_in >= (int32_t)H || w_in < 0 || + w_in >= (int32_t)W) { + continue; + } + + uint32_t input_idx = (h_in * W + w_in) * C + c; + uint32_t weight_idx = f * (P * Q * C) + p * (Q * C) + q * C + c; + + sum += pSrcA[input_idx] * weight_ptr[weight_idx]; + } + } + } - uint32_t output_idx = (h * W_out + w) * F_total + (ch_out_start + f); - 
pDstC[output_idx] = sum; + uint32_t output_idx = (h * W_out + w) * F_total + (ch_out_start + f); + pDstC[output_idx] = sum; + } } } } @@ -68,12 +104,17 @@ void PULP_Conv2d_fp32_fp32_fp32_HWC(const float32_t *__restrict__ pSrcA, void PULP_Conv2d_Im2Col_fp32_fp32_fp32_HWC( const float32_t *__restrict__ pSrcA, uint32_t H, uint32_t W, uint32_t C, const float32_t *__restrict__ pSrcB, uint32_t F_total, uint32_t P, - uint32_t Q, uint32_t SP, uint32_t SQ, float32_t *__restrict__ pDstC, - uint32_t pad_top, uint32_t pad_bottom, uint32_t pad_left, - uint32_t pad_right, float32_t *__restrict__ pContextBuffer) { + uint32_t Q, uint32_t SP, uint32_t SQ, + const float32_t *__restrict__ pSrcBias, const bool has_bias, + float32_t *__restrict__ pDstC, uint32_t pad_top, uint32_t pad_bottom, + uint32_t pad_left, uint32_t pad_right, + float32_t *__restrict__ pContextBuffer) { + + // Compute core int8_t core_id = pi_core_id(); int8_t log2Core = LOG2(NUM_CORES); + // Compute the chunk size for each core uint16_t ch_out_chunk = (F_total >> log2Core) + ((F_total & (NUM_CORES - 1)) != 0); uint16_t ch_out_start = MIN(ch_out_chunk * core_id, F_total); @@ -84,50 +125,95 @@ void PULP_Conv2d_Im2Col_fp32_fp32_fp32_HWC( return; } + // Pointer to the weights for the current core const float32_t *weight_ptr = pSrcB + ch_out_start * C * P * Q; uint32_t im2col_size_per_core = C * P * Q; float32_t *im2col_buffer = pContextBuffer + core_id * im2col_size_per_core; + // Compute the output dimensions uint32_t H_out = (H + pad_top + pad_bottom - P) / SP + 1; uint32_t W_out = (W + pad_left + pad_right - Q) / SQ + 1; uint32_t kernel_size = P * Q * C; - for (uint32_t h_out = 0; h_out < H_out; h_out++) { - for (uint32_t w_out = 0; w_out < W_out; w_out++) { - int32_t h_in_start = h_out * SP - pad_top; - int32_t w_in_start = w_out * SQ - pad_left; + // Compute the output + if (has_bias) { + for (uint32_t h_out = 0; h_out < H_out; h_out++) { + for (uint32_t w_out = 0; w_out < W_out; w_out++) { + int32_t h_in_start = h_out * SP - pad_top; + int32_t w_in_start = w_out * SQ - pad_left; + + for (uint32_t p = 0; p < P; p++) { + int32_t h_in = h_in_start + p; + + for (uint32_t q = 0; q < Q; q++) { + int32_t w_in = w_in_start + q; + + for (uint32_t c = 0; c < C; c++) { + if (h_in >= 0 && h_in < (int32_t)H && w_in >= 0 && + w_in < (int32_t)W) { + uint32_t in_idx = (h_in * W + w_in) * C + c; + im2col_buffer[p * Q * C + q * C + c] = pSrcA[in_idx]; + } else { + im2col_buffer[p * Q * C + q * C + c] = 0.0f; + } + } + } + } + + for (uint32_t f = ch_out_start; f < ch_out_stop; f++) { + float32_t sum = 0.0f; + const float32_t *local_weight_ptr = + weight_ptr + (f - ch_out_start) * kernel_size; - for (uint32_t p = 0; p < P; p++) { - int32_t h_in = h_in_start + p; + for (uint32_t k = 0; k < kernel_size; k++) { + sum += im2col_buffer[k] * local_weight_ptr[k]; + } - for (uint32_t q = 0; q < Q; q++) { - int32_t w_in = w_in_start + q; + uint32_t out_idx = (h_out * W_out + w_out) * F_total + f; - for (uint32_t c = 0; c < C; c++) { - if (h_in >= 0 && h_in < (int32_t)H && w_in >= 0 && - w_in < (int32_t)W) { - uint32_t in_idx = (h_in * W + w_in) * C + c; - im2col_buffer[p * Q * C + q * C + c] = pSrcA[in_idx]; - } else { - im2col_buffer[p * Q * C + q * C + c] = 0.0f; + pDstC[out_idx] = sum + pSrcBias[f]; + } + } + } + } else { + for (uint32_t h_out = 0; h_out < H_out; h_out++) { + for (uint32_t w_out = 0; w_out < W_out; w_out++) { + int32_t h_in_start = h_out * SP - pad_top; + int32_t w_in_start = w_out * SQ - pad_left; + + for (uint32_t p = 0; p < P; p++) { + 
-  for (uint32_t h_out = 0; h_out < H_out; h_out++) {
-    for (uint32_t w_out = 0; w_out < W_out; w_out++) {
-      int32_t h_in_start = h_out * SP - pad_top;
-      int32_t w_in_start = w_out * SQ - pad_left;
+  // Compute the output
+  if (has_bias) {
+    for (uint32_t h_out = 0; h_out < H_out; h_out++) {
+      for (uint32_t w_out = 0; w_out < W_out; w_out++) {
+        int32_t h_in_start = h_out * SP - pad_top;
+        int32_t w_in_start = w_out * SQ - pad_left;
+
+        for (uint32_t p = 0; p < P; p++) {
+          int32_t h_in = h_in_start + p;
+
+          for (uint32_t q = 0; q < Q; q++) {
+            int32_t w_in = w_in_start + q;
+
+            for (uint32_t c = 0; c < C; c++) {
+              if (h_in >= 0 && h_in < (int32_t)H && w_in >= 0 &&
+                  w_in < (int32_t)W) {
+                uint32_t in_idx = (h_in * W + w_in) * C + c;
+                im2col_buffer[p * Q * C + q * C + c] = pSrcA[in_idx];
+              } else {
+                im2col_buffer[p * Q * C + q * C + c] = 0.0f;
+              }
+            }
+          }
+        }
+
+        for (uint32_t f = ch_out_start; f < ch_out_stop; f++) {
+          float32_t sum = 0.0f;
+          const float32_t *local_weight_ptr =
+              weight_ptr + (f - ch_out_start) * kernel_size;
 
-      for (uint32_t p = 0; p < P; p++) {
-        int32_t h_in = h_in_start + p;
+          for (uint32_t k = 0; k < kernel_size; k++) {
+            sum += im2col_buffer[k] * local_weight_ptr[k];
+          }
 
-        for (uint32_t q = 0; q < Q; q++) {
-          int32_t w_in = w_in_start + q;
+          uint32_t out_idx = (h_out * W_out + w_out) * F_total + f;
 
-          for (uint32_t c = 0; c < C; c++) {
-            if (h_in >= 0 && h_in < (int32_t)H && w_in >= 0 &&
-                w_in < (int32_t)W) {
-              uint32_t in_idx = (h_in * W + w_in) * C + c;
-              im2col_buffer[p * Q * C + q * C + c] = pSrcA[in_idx];
-            } else {
-              im2col_buffer[p * Q * C + q * C + c] = 0.0f;
+          pDstC[out_idx] = sum + pSrcBias[f];
+        }
+      }
+    }
+  } else {
+    for (uint32_t h_out = 0; h_out < H_out; h_out++) {
+      for (uint32_t w_out = 0; w_out < W_out; w_out++) {
+        int32_t h_in_start = h_out * SP - pad_top;
+        int32_t w_in_start = w_out * SQ - pad_left;
+
+        for (uint32_t p = 0; p < P; p++) {
+          int32_t h_in = h_in_start + p;
+
+          for (uint32_t q = 0; q < Q; q++) {
+            int32_t w_in = w_in_start + q;
+
+            for (uint32_t c = 0; c < C; c++) {
+              if (h_in >= 0 && h_in < (int32_t)H && w_in >= 0 &&
+                  w_in < (int32_t)W) {
+                uint32_t in_idx = (h_in * W + w_in) * C + c;
+                im2col_buffer[p * Q * C + q * C + c] = pSrcA[in_idx];
+              } else {
+                im2col_buffer[p * Q * C + q * C + c] = 0.0f;
+              }
            }
          }
        }
-      }
 
-      for (uint32_t f = 0; f < ch_out_count; f++) {
-        float32_t sum = 0.0f;
-        const float32_t *local_weight_ptr = weight_ptr + f * kernel_size;
+        for (uint32_t f = ch_out_start; f < ch_out_stop; f++) {
+          float32_t sum = 0.0f;
+          const float32_t *local_weight_ptr =
+              weight_ptr + (f - ch_out_start) * kernel_size;
 
-        for (uint32_t k = 0; k < kernel_size; k++) {
-          sum += im2col_buffer[k] * local_weight_ptr[k];
-        }
+          for (uint32_t k = 0; k < kernel_size; k++) {
+            sum += im2col_buffer[k] * local_weight_ptr[k];
+          }
 
-        uint32_t out_idx =
-            (h_out * W_out + w_out) * F_total + (ch_out_start + f);
-        pDstC[out_idx] = sum;
+          uint32_t out_idx = (h_out * W_out + w_out) * F_total + f;
+
+          pDstC[out_idx] = sum;
+        }
       }
     }
   }
-}
\ No newline at end of file
+}
diff --git a/TargetLibraries/PULPOpen/src/DWConvolution_fp32.c b/TargetLibraries/PULPOpen/src/DWConvolution_fp32.c
new file mode 100644
index 000000000..b0a06c66e
--- /dev/null
+++ b/TargetLibraries/PULPOpen/src/DWConvolution_fp32.c
@@ -0,0 +1,251 @@
+/*
+ * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#include "DeeployPULPMath.h"
+#include "pmsis.h"
+
+void PULP_DW_Conv2d_Im2Col_fp32_fp32_fp32_HWC(
+    const float32_t *__restrict__ pSrcA, uint32_t H, uint32_t W, uint32_t C,
+    const float32_t *__restrict__ pSrcB, uint32_t F_total, uint32_t P,
+    uint32_t Q, uint32_t SP, uint32_t SQ,
+    const float32_t *__restrict__ pSrcBias, const bool has_bias,
+    float32_t *__restrict__ pDstC, uint32_t pad_top, uint32_t pad_bottom,
+    uint32_t pad_left, uint32_t pad_right,
+    float32_t *__restrict__ pContextBuffer) {
+
+  // Compute core information
+  int8_t core_id = pi_core_id();
+  int8_t log2Core = LOG2(NUM_CORES);
+
+  // Compute the chunk size for each core
+  // (splitting the work along the output channels)
+  uint16_t ch_out_chunk =
+      (F_total >> log2Core) + ((F_total & (NUM_CORES - 1)) != 0);
+  uint16_t ch_out_start = MIN(ch_out_chunk * core_id, F_total);
+  uint16_t ch_out_stop = MIN(ch_out_start + ch_out_chunk, F_total);
+  uint16_t ch_out_count = ch_out_stop - ch_out_start;
+
+  // If there is no output channel to process, return
+  // (happens when F_total < NUM_CORES on the cores whose ID is >= F_total)
+  if (ch_out_count == 0) {
+    return;
+  }
+
+  // Advance the weight pointer to this core's first channel
+  const float32_t *weight_ptr = pSrcB + ch_out_start * P * Q;
+
+  // Advance the im2col buffer pointer to this core's scratch area
+  uint32_t im2col_size_per_core = P * Q;
+  float32_t *im2col_buffer = pContextBuffer + core_id * im2col_size_per_core;
+
+  // Compute the output dimensions
+  uint32_t H_out = (H + pad_top + pad_bottom - P) / SP + 1;
+  uint32_t W_out = (W + pad_left + pad_right - Q) / SQ + 1;
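+
+  // Example (hypothetical sizes): H = W = 16, P = Q = 3, SP = SQ = 1 and
+  // padding 1 on every side give
+  //   H_out = (16 + 1 + 1 - 3) / 1 + 1 = 16,
+  // i.e. a "same"-style depthwise convolution.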
+
+  // Compute the output
+  if (has_bias) {
+    // Work on individual output elements
+    // (each element depends on a column of the im2col buffer and on one
+    // convolutional filter, stored contiguously in memory)
+    for (uint32_t h_out = 0; h_out < H_out; h_out++) {
+      for (uint32_t w_out = 0; w_out < W_out; w_out++) {
+        // Compute height and width starting point
+        // (depending on stride and padding)
+        int32_t h_in_start = h_out * SP - pad_top;
+        int32_t w_in_start = w_out * SQ - pad_left;
+
+        // Initialize the padded part of the im2col buffer with 0
+        // Work on the TOP padding
+        for (int32_t h_in = (int32_t)h_in_start;
+             h_in < MIN(0, (int32_t)(h_in_start + P)); h_in++) {
+          for (int32_t w_in = (int32_t)w_in_start;
+               w_in < (int32_t)(w_in_start + Q); w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Work on the BOTTOM padding
+        for (uint32_t h_in = MAX(H, h_in_start); h_in < h_in_start + P;
+             h_in++) {
+          for (int32_t w_in = (int32_t)w_in_start;
+               w_in < (int32_t)(w_in_start + Q); w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Work on the remaining LEFT padding
+        for (uint32_t h_in = MAX(0, h_in_start); h_in < MIN(H, h_in_start + P);
+             h_in++) {
+          for (int32_t w_in = (int32_t)w_in_start;
+               w_in < MIN(0, (int32_t)(w_in_start + Q)); w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Work on the remaining RIGHT padding
+        for (uint32_t h_in = MAX(0, h_in_start); h_in < MIN(H, h_in_start + P);
+             h_in++) {
+          for (uint32_t w_in = MAX(W, w_in_start); w_in < w_in_start + Q;
+               w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Copy input data to the im2col buffer
+        // The input channels to visit depend on the output channels assigned
+        // to this core: each input channel produces F_total / C output
+        // channels (the channel multiplier), while the number of groups,
+        // i.e. the "group" attribute of the ONNX Conv operator, equals C
+        // (the loop bound is the ceiling of ch_out_stop / (F_total / C))
+        for (uint32_t c = ch_out_start / (F_total / C);
+             c < (ch_out_stop + (F_total / C) - 1) / (F_total / C); c++) {
+          // Copy the valid input data to the im2col buffer
+          for (uint32_t h_in = MAX(0, h_in_start);
+               h_in < MIN(H, h_in_start + P); h_in++) {
+            for (uint32_t w_in = MAX(0, w_in_start);
+                 w_in < MIN(W, w_in_start + Q); w_in++) {
+              uint32_t in_idx = (h_in * W + w_in) * C + c;
+              im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] =
+                  pSrcA[in_idx];
+            }
+          }
+
+          // Compute the output channels of interest, based on the current
+          // input channel and this core's range
+          uint32_t lower_f, upper_f;
+
+          if (c * (F_total / C) < ch_out_start) {
+            lower_f = ch_out_start;
+          } else {
+            lower_f = c * (F_total / C);
+          }
+
+          if ((c + 1) * (F_total / C) < ch_out_stop) {
+            upper_f = (c + 1) * (F_total / C);
+          } else {
+            upper_f = ch_out_stop;
+          }
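+
+          // Example (hypothetical sizes): C = 8, F_total = 16 (channel
+          // multiplier 2) and a core range [5, 11) visit input channels
+          // c = 2 .. 5; for c = 2, lower_f = 5 (clamped to the core range)
+          // and upper_f = 6, while for c = 3, lower_f = 6 and upper_f = 8.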
+          // Perform convolution for the assigned output channels
+          for (uint32_t f = lower_f; f < upper_f; f++) {
+            float32_t sum = 0.0f;
+            uint32_t out_idx = (h_out * W_out + w_out) * F_total + f;
+
+            for (uint32_t im2col_idx = 0; im2col_idx < P * Q; im2col_idx++) {
+              sum += im2col_buffer[im2col_idx] *
+                     weight_ptr[(f - ch_out_start) * P * Q + im2col_idx];
+            }
+
+            // Copy the result to the output tensor
+            pDstC[out_idx] = sum + pSrcBias[f];
+          }
+        }
+      }
+    }
+  } else {
+    // Work on individual output elements
+    // (each element depends on a column of the im2col buffer and on one
+    // convolutional filter, stored contiguously in memory)
+    for (uint32_t h_out = 0; h_out < H_out; h_out++) {
+      for (uint32_t w_out = 0; w_out < W_out; w_out++) {
+        // Compute height and width starting point
+        // (depending on stride and padding)
+        int32_t h_in_start = h_out * SP - pad_top;
+        int32_t w_in_start = w_out * SQ - pad_left;
+
+        // Initialize the padded part of the im2col buffer with 0
+        // Work on the TOP padding
+        for (int32_t h_in = (int32_t)h_in_start;
+             h_in < MIN(0, (int32_t)(h_in_start + P)); h_in++) {
+          for (int32_t w_in = (int32_t)w_in_start;
+               w_in < (int32_t)(w_in_start + Q); w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Work on the BOTTOM padding
+        for (uint32_t h_in = MAX(H, h_in_start); h_in < h_in_start + P;
+             h_in++) {
+          for (int32_t w_in = (int32_t)w_in_start;
+               w_in < (int32_t)(w_in_start + Q); w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Work on the remaining LEFT padding
+        for (uint32_t h_in = MAX(0, h_in_start); h_in < MIN(H, h_in_start + P);
+             h_in++) {
+          for (int32_t w_in = (int32_t)w_in_start;
+               w_in < MIN(0, (int32_t)(w_in_start + Q)); w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Work on the remaining RIGHT padding
+        for (uint32_t h_in = MAX(0, h_in_start); h_in < MIN(H, h_in_start + P);
+             h_in++) {
+          for (uint32_t w_in = MAX(W, w_in_start); w_in < w_in_start + Q;
+               w_in++) {
+            im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] = 0.0f;
+          }
+        }
+
+        // Copy input data to the im2col buffer
+        // The input channels to visit depend on the output channels assigned
+        // to this core: each input channel produces F_total / C output
+        // channels (the channel multiplier), while the number of groups,
+        // i.e. the "group" attribute of the ONNX Conv operator, equals C
+        // (the loop bound is the ceiling of ch_out_stop / (F_total / C))
+        for (uint32_t c = ch_out_start / (F_total / C);
+             c < (ch_out_stop + (F_total / C) - 1) / (F_total / C); c++) {
+          // Copy the valid input data to the im2col buffer
+          for (uint32_t h_in = MAX(0, h_in_start);
+               h_in < MIN(H, h_in_start + P); h_in++) {
+            for (uint32_t w_in = MAX(0, w_in_start);
+                 w_in < MIN(W, w_in_start + Q); w_in++) {
+              uint32_t in_idx = (h_in * W + w_in) * C + c;
+              im2col_buffer[(h_in - h_in_start) * Q + (w_in - w_in_start)] =
+                  pSrcA[in_idx];
+            }
+          }
+
+          // Compute the output channels of interest, based on the current
+          // input channel and this core's range
+          uint32_t lower_f, upper_f;
+
+          if (c * (F_total / C) < ch_out_start) {
+            lower_f = ch_out_start;
+          } else {
+            lower_f = c * (F_total / C);
+          }
+
+          if ((c + 1) * (F_total / C) < ch_out_stop) {
+            upper_f = (c + 1) * (F_total / C);
+          } else {
+            upper_f = ch_out_stop;
+          }
+
+          // Perform convolution for the assigned output channels
+          for (uint32_t f = lower_f; f < upper_f; f++) {
+            float32_t sum = 0.0f;
+            uint32_t out_idx = (h_out * W_out + w_out) * F_total + f;
+
+            for (uint32_t im2col_idx = 0; im2col_idx < P * Q; im2col_idx++) {
+              sum += im2col_buffer[im2col_idx] *
+                     weight_ptr[(f - ch_out_start) * P * Q + im2col_idx];
+            }
+
+            // Copy the result to the output tensor
+            pDstC[out_idx] = sum;
+          }
+        }
+      }
+    }
+  }
+
+  return;
+}
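
For context, a minimal sketch of how the new depthwise kernel might be dispatched on the PULP cluster; the buffer names and tensor sizes are hypothetical, and in Deeploy the actual call is emitted by the PULPOpen code-generation templates:

// Hypothetical dispatch sketch (not part of the patch): fork the FP32
// depthwise kernel across the cluster cores with pi_cl_team_fork.
#include "DeeployPULPMath.h"
#include "pmsis.h"

#define IM_H 16
#define IM_W 16
#define IM_C 8
#define K_P 3
#define K_Q 3

static float32_t in[IM_H * IM_W * IM_C];          // HWC input
static float32_t weights[IM_C * K_P * K_Q];       // one P x Q filter per channel
static float32_t bias[IM_C];
static float32_t out[IM_H * IM_W * IM_C];         // stride 1, "same" padding
static float32_t ctx[NUM_CORES * K_P * K_Q];      // per-core im2col scratch

static void dw_conv_entry(void *arg) {
  (void)arg;
  // F_total == C gives a channel multiplier of 1, i.e. classic depthwise
  PULP_DW_Conv2d_Im2Col_fp32_fp32_fp32_HWC(
      in, IM_H, IM_W, IM_C, weights, /* F_total */ IM_C, K_P, K_Q,
      /* SP, SQ */ 1, 1, bias, /* has_bias */ true, out,
      /* pads (top, bottom, left, right) */ 1, 1, 1, 1, ctx);
}

// From the cluster controller:
//   pi_cl_team_fork(NUM_CORES, dw_conv_entry, NULL);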