diff --git a/.github/workflows/ci-platform-generic.yml b/.github/workflows/ci-platform-generic.yml
index 3df422a74..321bed699 100644
--- a/.github/workflows/ci-platform-generic.yml
+++ b/.github/workflows/ci-platform-generic.yml
@@ -96,3 +96,4 @@ jobs:
           CCT/CCT_1_16_16_8
           CCT/CCT_2_32_32_128_Opset20
           testFloatDemoTinyViT
+          Autoencoder1D
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5421cdf52..158138ccf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
 ## Unreleased (Planned Release Target: v0.2.1)
 ### List of Pull Requests
+- Support for 1D Autoencoder [#98](https://github.com/pulp-platform/Deeploy/pull/98)
 - Refactor Logging for Improved Debugging [#115](https://github.com/pulp-platform/Deeploy/pull/115)
 - Add reuse-tool as an SPDX license header linter [#113](https://github.com/pulp-platform/Deeploy/pull/113)
 - Bug fixes, API Cleanup and Reduce Compiler Warning on PULP [#112](https://github.com/pulp-platform/Deeploy/pull/112)
@@ -158,6 +159,13 @@ This release containing major architectural changes, new platform support, enhan
 ### Added
+- BatchNorm kernel
+- ConvTranspose kernel
+- MaxPool1D kernel
+- Template for 1D Convolution
+- Support for the float32 data type in the kernels listed above
+- Float binding for the Pad1D kernel
+- Test for Autoencoder1D in the CI pipeline
 - ChimeraDeployer, currently mainly a placeholder
 - Allocate templates for Chimera
 - ChimeraPlatform, using appropriate allocation templates and using the generic Parser + Binding for the Add node
@@ -291,6 +299,8 @@ This release containing major architectural changes, new platform support, enhan
 - `dev-requirements.txt` tracking the dependencies of the build system, linting, documentation, and QOL.
 ### Changed
+- FloatConvTemplate, extended with a 1D float convolution reference template
+- Platform.py, extended to register the new 1D mappers and bindings
 - Bump the CMake version to 3.24 as required for the chimera-sdk
 - Bump GVSoC's version and add chimera simulation target
 - Rename the generic source util to utils to avoid name collision with chimera-sdk
diff --git a/Deeploy/Targets/Generic/Bindings.py b/Deeploy/Targets/Generic/Bindings.py
index 24fc8c0d2..6bfe805b3 100644
--- a/Deeploy/Targets/Generic/Bindings.py
+++ b/Deeploy/Targets/Generic/Bindings.py
@@ -11,19 +11,20 @@
     int8_t, int32_t, uint8_t
 from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
 from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
-from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, ConvTemplate, DebugPrintTemplate, \
-    DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, FloatConvTemplate, FloatDivTemplate, \
-    FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \
-    FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, FloatReduceMeanTemplate, FloatReluTemplate, \
-    FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, \
-    MatMulTemplate, MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
-    RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
-    iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
-from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DebugPrintChecker, \
-    DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, LayerNormChecker, \
-    MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, ReduceSumChecker, \
-    ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, SoftmaxChecker, \
-    TransposeChecker
+from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \
+    ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
+    FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \
+    FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \
+    FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \
+    IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \
+    PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \
+    RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \
+    iRMSNormTemplate, iSoftmaxTemplate
+from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
+    DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
+    LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
+    ReduceSumChecker, ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, \
+    SoftmaxChecker, TransposeChecker
 
 BasicTransformer = CodeTransformation([ArgumentStructGeneration(), MemoryManagementGeneration(), FutureGeneration()])
 
@@ -53,8 +54,14 @@
                 FloatAddTemplate.referenceTemplate, BasicTransformer)
 ]
 
-BasicConv1DBinding = NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
-                                 ConvTemplate.reference1DTemplate, BasicTransformer)
+BasicConv1DBindings = [
+    NodeBinding(ConvChecker(
+        [PointerClass(type), PointerClass(type), PointerClass(type)], [PointerClass(type)]),
+                FloatConvTemplate.reference1DTemplate, BasicTransformer) for type in FloatDataTypes
+] + [
+    NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
+                ConvTemplate.reference1DTemplate, BasicTransformer)
+]
 
 BasicDWConv1DBinding = NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
                                    DWConvTemplate.reference1DTemplate, BasicTransformer)
 
@@ -147,6 +154,11 @@
                 FloatMatMulTemplate.referenceTemplate, BasicTransformer)
 ]
 
+BasicMaxPool1DBindings = [
+    NodeBinding(MaxPoolChecker([PointerClass(type)], [PointerClass(type)]), FloatMaxPoolTemplate.reference1DTemplate,
+                BasicTransformer) for type in FloatDataTypes
+]
+
 BasicMaxPool2DBindings = [
     NodeBinding(MaxPoolChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), MaxPoolTemplate.referenceTemplate,
                 BasicTransformer)
@@ -167,7 +179,11 @@
 BasicPad1DBindings = [
     NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), PadTemplate.reference1DTemplate,
                 BasicTransformer) for type in SignedIntegerDataTypes
+] + [
+    NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), FloatPadTemplate.reference1DTemplate,
+                BasicTransformer) for type in FloatDataTypes
 ]
+
 BasicPad2DBindings = [
     NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), PadTemplate.reference2DTemplate,
                 BasicTransformer) for type in SignedIntegerDataTypes
@@ -266,3 +282,30 @@
     NodeBinding(DequantChecker([PointerClass(int32_t)], [PointerClass(float32_t)]),
                 DequantTemplate.referenceTemplate, BasicTransformer),
 ]
+
+BasicBatchNormBindings = [
+    NodeBinding(
+        BatchNormChecker(
+            [PointerClass(type),
+             PointerClass(type),
+             PointerClass(type),
+             PointerClass(type),
+             PointerClass(type)], [PointerClass(type)]), BatchNormalizationTemplate.referenceTemplate, BasicTransformer)
+    for type in FloatDataTypes
+]
+
+BasicConvTransposeBindings = [
+    NodeBinding(
+        ConvChecker(
+            [PointerClass(type), PointerClass(type), PointerClass(type)],  # input, weight, bias
+            [PointerClass(type)]),
+        ConvTransposeTemplate.referenceTemplate,
+        BasicTransformer) for type in FloatDataTypes
+] + [
+    NodeBinding(
+        ConvChecker(
+            [PointerClass(type), PointerClass(type)],  # input, weight
+            [PointerClass(type)]),
+        ConvTransposeTemplate.referenceTemplate,
+        BasicTransformer) for type in FloatDataTypes
+]
diff --git a/Deeploy/Targets/Generic/Layers.py b/Deeploy/Targets/Generic/Layers.py
index e01f5e79d..c924895c1 100644
--- a/Deeploy/Targets/Generic/Layers.py
+++ b/Deeploy/Targets/Generic/Layers.py
@@ -618,3 +618,64 @@ class DequantLayer(ONNXLayer):
 
     def __init__(self, maps: List[NodeMapper]):
         super().__init__(maps)
+
+
+class BatchNormalizationLayer(ONNXLayer):
+
+    def __init__(self, maps: List[NodeMapper]):
+        super().__init__(maps)
+
+    def computeOps(self):
+        # Roughly 5 operations per element: sub, sqrt, div, mul, add
+        B = self.mapper.parser.operatorRepresentation['batch_size']
+        C = self.mapper.parser.operatorRepresentation['channel_size']
+        W = self.mapper.parser.operatorRepresentation['window_size']
+        return B * C * W * 5
+
+
+class ConvTransposeLayer(ONNXLayer):
+
+    def __init__(self, maps: List[NodeMapper]):
+        super().__init__(maps)
+
+    def computeShapes(self, inputShapes: Shape, outputShapes: Shape, operatorRepresentation,
+                      channels_first) -> Tuple[Shape, Shape]:
+        """
+        Infers output shapes for ConvTranspose using only static info.
+        - inputShapes[0]: input tensor shape (e.g., [N, C_in, W] for 1D, [N, C_in, H, W] for 2D)
+        - inputShapes[1]: weight tensor shape (e.g., [C_in, C_out // group, kW] for 1D)
+        - outputShapes[0]: output tensor shape (to be updated)
+        """
+        newInputShapes = list(inputShapes)
+        newOutputShapes = list(outputShapes)
+        group = operatorRepresentation.get('group', 1)
+        weight_shape = inputShapes[1]
+
+        if newOutputShapes and len(newOutputShapes[0]) >= 2:
+            # For 1D: weight_shape = [C_in, C_out // group, kW]
+            # For 2D: weight_shape = [C_in, C_out // group, kH, kW]
+            ch_out = weight_shape[1] * group
+            if channels_first:
+                newOutputShapes[0][1] = ch_out
+            else:
+                newOutputShapes[0][-1] = ch_out
+
+        return newInputShapes, newOutputShapes
+
+    def computeOps(self):
+        opRep = self.mapper.parser.operatorRepresentation
+
+        groups = opRep.get('group', 1)
+        kernel_shape = np.prod(opRep['kernel_shape'])  # e.g. [3, 3] -> 9
+        ch_in = opRep['ch_im_in']
+        ch_out = opRep['ch_im_out']
+
+        opsPerPx = int(kernel_shape * ch_in * ch_out / groups) * 2
+
+        # ConvTranspose upscales the spatial dims, so the pixel count comes from the output shape
+        if 'dim_im_out_y' in opRep:
+            numPx = opRep['dim_im_out_x'] * opRep['dim_im_out_y']
+        else:
+            numPx = opRep['dim_im_out_x']
+
+        return numPx * opsPerPx
diff --git a/Deeploy/Targets/Generic/Parsers.py b/Deeploy/Targets/Generic/Parsers.py
index 3c3a3472c..adc48ffe1 100644
--- a/Deeploy/Targets/Generic/Parsers.py
+++ b/Deeploy/Targets/Generic/Parsers.py
@@ -221,6 +221,48 @@ def parseNodeCtxt(self,
         return ctxt, True
 
 
+class MaxPool1DParser(MaxPoolParser):
+
+    def __init__(self):
+        super().__init__()
+
+    def parseNode(self, node: gs.Node) -> bool:
+        ret = super().parseNode(node)
+        wellFormed = False
+        if ret:
+            pads = self.operatorRepresentation['pads']
+            kernel_shape = self.operatorRepresentation['kernel_shape']
+            strides = self.operatorRepresentation['strides']
+            # 1D: pads should be length 2, kernel_shape length 1, strides length 1
+            if len(pads) == 2 and len(kernel_shape) == 1 and len(strides) == 1:
+                wellFormed = True
+                self.operatorRepresentation['padding_y'] = int(pads[0])
+                self.operatorRepresentation['padding_y_right'] = int(pads[1])
+                self.operatorRepresentation['stride_y'] = int(strides[0])
+                self.operatorRepresentation['dim_kernel_y'] = int(kernel_shape[0])
+        return wellFormed
+
+    def parseNodeCtxt(self, ctxt, node, channels_first = True):
+        newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
+        if ret:
+            data_in = newCtxt.lookup(self.operatorRepresentation['data_in'])
+            data_out = newCtxt.lookup(self.operatorRepresentation['data_out'])
+            self.operatorRepresentation['batch'] = data_in.shape[0]
+            if channels_first:
+                self.operatorRepresentation['ch_im_in'] = data_in.shape[1]
+                self.operatorRepresentation['dim_im_in_y'] = data_in.shape[2]
+                self.operatorRepresentation['ch_im_out'] = data_out.shape[1]
+                self.operatorRepresentation['dim_im_out_y'] = data_out.shape[2]
+            else:
+                self.operatorRepresentation['ch_im_in'] = data_in.shape[2]
+                self.operatorRepresentation['dim_im_in_y'] = data_in.shape[1]
+                self.operatorRepresentation['ch_im_out'] = data_out.shape[2]
+                self.operatorRepresentation['dim_im_out_y'] = data_out.shape[1]
+            if len(data_in.shape) == 3 and len(data_out.shape) == 3:
+                return newCtxt, True
+        return ctxt, False
+
+
 class MaxPool2DParser(MaxPoolParser):
 
     def __init__(self):
@@ -298,7 +340,12 @@ def parseNode(self, node: gs.Node) -> bool:
 
         if ret:
             self.operatorRepresentation['mode'] = node.attrs['mode']
-            self.operatorRepresentation['pads'] = node.attrs['pads']
+
+            try:
+                self.operatorRepresentation['pads'] = [int(p) for p in node.attrs['pads']]
+            except Exception:  # fall back to the raw attribute if the values are not integer-convertible
+                self.operatorRepresentation['pads'] = node.attrs['pads']
+
             self.operatorRepresentation['value'] = node.attrs['value']
 
         return ret
@@ -1325,6 +1372,8 @@ def parseNodeCtxt(self,
 
             self.operatorRepresentation['batch'] = data_in.shape[0]
             self.operatorRepresentation['dim_im_in_x'] = 1
+
+            # Necessary, since we use the same ConvLayer for all convolutions
             self.operatorRepresentation['dim_im_out_x'] = 1
 
             if channels_first:
@@ -1338,6 +1387,11 @@ def parseNodeCtxt(self,
                 self.operatorRepresentation['ch_im_out'] = data_out.shape[2]
                 self.operatorRepresentation['dim_im_out_y'] = data_out.shape[1]
 
+            self.operatorRepresentation[
+                'batchOffsetIn'] = self.operatorRepresentation['ch_im_in'] * self.operatorRepresentation['dim_im_in_y']
+            self.operatorRepresentation['batchOffsetOut'] = self.operatorRepresentation[
+                'ch_im_out'] * self.operatorRepresentation['dim_im_out_y']
+
             if len(data_in.shape) == 3 and len(weight.shape) == 3:
                 return newCtxt, True
 
@@ -2136,7 +2190,20 @@ def parseNodeCtxt(self,
 
         if ret:
             inputs = ['data_in', 'weight']
+
+            # Handle bias, if present
+            if len(node.inputs) > 2:
+                inputs.append("bias")
+                self.operatorRepresentation["has_bias"] = 1
+            else:
+                self.operatorRepresentation["has_bias"] = 0
+                self.operatorRepresentation["bias"] = "NULL"
+
             for idx, inputNode in enumerate(node.inputs):
+                if idx >= len(inputs):
+                    raise IndexError(
+                        f"Index {idx} out of range for inputs of length {len(inputs)} in node {inputNode.name}")
                 self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
 
         return newCtxt, True
@@ -2555,3 +2622,171 @@ def parseNodeCtxt(self,
 
         self.operatorRepresentation['lr'] = node.attrs['lr']
 
         return ctxt, True
+
+
+class BatchNormParser(NodeParser):
+
+    def __init__(self):
+        super().__init__()
+
+    def parseNode(self, node: gs.Node) -> bool:
+        # Verify the attributes (epsilon is mandatory, momentum and training_mode are optional)
+        if 'epsilon' not in node.attrs:
+            return False
+        # Common inputs: 5 (X, scale, B, mean, var)
+        if len(node.inputs) < 5:
+            return False
+
+        # Save the attributes; default values are provided if not present
+        self.operatorRepresentation['epsilon'] = node.attrs.get('epsilon', 1e-5)
+        self.operatorRepresentation['momentum'] = node.attrs.get('momentum', 0.9)
+        self.operatorRepresentation['training_mode'] = node.attrs.get('training_mode', 0)
+
+        return True
+
+    def parseNodeCtxt(self, ctxt, node: gs.Node, channels_first: bool = True):
+        inputs = ['data_in', 'scale', 'bias', 'mean', 'variance']
+        outputs = ['data_out']
+
+        for idx, inputNode in enumerate(node.inputs[:5]):
+            self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
+
+        # Output (Y)
+        self.operatorRepresentation[outputs[0]] = ctxt.lookup(node.outputs[0].name).name
+
+        input_shape = ctxt.lookup(node.inputs[0].name).shape
+        # Save input shape information
+        self.operatorRepresentation['batch_size'] = input_shape[0]
+        self.operatorRepresentation['channel_size'] = input_shape[1]
+        self.operatorRepresentation['window_size'] = input_shape[2]
+
+        return ctxt, True
+
+
+class ConvTransposeParser(NodeParser):
+
+    def __init__(self):
+        super().__init__()
+
+    def parseNode(self, node: gs.Node) -> bool:
+        # Extract ONNX attributes with defaults
+        strides = node.attrs.get('strides', [1])
+
+        pads = node.attrs.get('pads', [0, 0])
+        kernel_shape = node.attrs.get('kernel_shape', None)
+        dilations = node.attrs.get('dilations', [1])
+        group = node.attrs.get('group', 1)
+
+        # Check for required attributes
+        wellFormed = (kernel_shape is not None and len(node.outputs) == 1)
+        if wellFormed:
+            self.operatorRepresentation['strides'] = strides
+            self.operatorRepresentation['pads'] = pads
+            self.operatorRepresentation['kernel_shape'] = kernel_shape
+            self.operatorRepresentation['dilations'] = dilations
+            self.operatorRepresentation['group'] = group
+            self.operatorRepresentation['nodeName'] = node.name
+            self.operatorRepresentation['nodeOp'] = node.op
+        return wellFormed
+
+    def parseNodeCtxt(self, ctxt: NetworkContext, node: gs.Node, channels_first: bool = True):
+        # Register buffer names for codegen
+        self.operatorRepresentation['data_in'] = node.inputs[0].name
+        self.operatorRepresentation['weight'] = node.inputs[1].name
+        self.operatorRepresentation['data_out'] = node.outputs[0].name
+        if len(node.inputs) == 3:
+            self.operatorRepresentation['bias'] = node.inputs[2].name
+            self.operatorRepresentation['has_bias'] = "true"
+        else:
+            self.operatorRepresentation['has_bias'] = "false"
+        # Get output shape from context
+        data_out = ctxt.lookup(node.outputs[0].name)
+        out_shape = data_out.shape
+        if len(out_shape) == 3:
+            self.operatorRepresentation['dim_im_out_x'] = out_shape[2]
+        elif len(out_shape) == 4:
+            self.operatorRepresentation['dim_im_out_x'] = out_shape[2]
+            self.operatorRepresentation['dim_im_out_y'] = out_shape[3]
+
+        stride_x, stride_y = 1, 1
+        if "strides" in node.attrs:
+            stride_y = node.attrs["strides"][0]
+            stride_x = node.attrs["strides"][1] if len(node.attrs["strides"]) > 1 else stride_y
+        self.operatorRepresentation["stride_y"] = stride_y
+        self.operatorRepresentation["stride_x"] = stride_x
+
+        if "kernel_shape" in node.attrs:
+            kernel_shape = node.attrs["kernel_shape"]
+            kernel_shape_x = kernel_shape[0]
+            # For 2D, kernel_shape may have two elements
+            kernel_shape_y = kernel_shape[1] if len(kernel_shape) > 1 else kernel_shape_x
+        else:
+            kernel_shape_x = 1
+            kernel_shape_y = 1
+
+        data_in = ctxt.lookup(node.inputs[0].name)
+        data_out = ctxt.lookup(node.outputs[0].name)
+        in_shape = data_in.shape
+        out_shape = data_out.shape
+
+        self.operatorRepresentation['ch_im_in'] = in_shape[1]
+        self.operatorRepresentation['dim_im_in_y'] = in_shape[2]
+        self.operatorRepresentation['ch_im_out'] = out_shape[1]
+        self.operatorRepresentation['dim_im_out_y'] = out_shape[2]
+
+        self.operatorRepresentation[
+            'batchOffsetIn'] = self.operatorRepresentation['ch_im_in'] * self.operatorRepresentation['dim_im_in_y']
+        self.operatorRepresentation[
+            'batchOffsetOut'] = self.operatorRepresentation['ch_im_out'] * self.operatorRepresentation['dim_im_out_y']
+        return ctxt, True
+
+
+class ConvTranspose1DParser(ConvTransposeParser):
+
+    def __init__(self):
+        super().__init__()
+
+    def parseNode(self, node: gs.Node) -> bool:
+        # 1D ConvTranspose expects 3D input/output and 3D weight
+        wellFormed = super().parseNode(node)
+        ret = False
+        if wellFormed:
+            ret = all([
+                # Make sure the attributes are 1D
+                len(node.attrs['strides']) == 1,
+                len(node.attrs['pads']) == 2,
+                len(node.attrs['dilations']) == 1,
+            ])
+        if ret:
+
+            self.operatorRepresentation['kernel_shape'] = node.attrs['kernel_shape']
+            self.operatorRepresentation['dim_kernel_y'] = int(self.operatorRepresentation['kernel_shape'][0])
+            self.operatorRepresentation['dilation_y'] = int(self.operatorRepresentation['dilations'][0])
+            self.operatorRepresentation['padding_y'] = int(self.operatorRepresentation['pads'][0])
+            self.operatorRepresentation['stride_y'] = int(self.operatorRepresentation['strides'][0])
+
+        return ret
+
+    def parseNodeCtxt(self,
+                      ctxt: NetworkContext,
+                      node: gs.Node,
+                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
+
+        newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
+
+        if ret:
+            data_in = newCtxt.lookup(node.inputs[0].name)
+            data_out = newCtxt.lookup(node.outputs[0].name)
+            in_shape = data_in.shape
+            out_shape = data_out.shape
+            self.operatorRepresentation['batch'] = in_shape[0]
+            self.operatorRepresentation['ch_im_in'] = in_shape[1]
+            self.operatorRepresentation['dim_im_in_y'] = in_shape[2]
+            self.operatorRepresentation['ch_im_out'] = out_shape[1]
+            self.operatorRepresentation['dim_im_out_y'] = out_shape[2]
+            self.operatorRepresentation[
+                "batchOffsetIn"] = self.operatorRepresentation["ch_im_in"] * self.operatorRepresentation["dim_im_in_y"]
+            self.operatorRepresentation["batchOffsetOut"] = self.operatorRepresentation[
self.operatorRepresentation["dim_im_out_y"] + return newCtxt, True + return ctxt, False diff --git a/Deeploy/Targets/Generic/Platform.py b/Deeploy/Targets/Generic/Platform.py index c09b89df9..a15b3db2e 100644 --- a/Deeploy/Targets/Generic/Platform.py +++ b/Deeploy/Targets/Generic/Platform.py @@ -6,30 +6,33 @@ RemoveEmptyConvBiasPass from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NodeMapper, NodeTemplate, \ StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer -from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicConcatBindings, BasicConv1DBinding, \ - BasicConv2DBindings, BasicDebugPrintBindings, BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, \ - BasicDWConv2DBindings, BasicGatherBindings, BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, \ - BasicITASoftmaxBinding, BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool2DBindings, BasicMulBindings, \ +from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicBatchNormBindings, BasicConcatBindings, \ + BasicConv1DBindings, BasicConv2DBindings, BasicConvTransposeBindings, BasicDebugPrintBindings, \ + BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, BasicGatherBindings, \ + BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, \ + BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, \ BasicPad1DBindings, BasicPad2DBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, \ BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, \ BasicSliceBindings, BasicSoftmaxBindings, BasicTransposeBindings, DummyBinding -from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, DebugPrintLayer, DequantLayer, DivLayer, \ - GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, \ - QuantLayer, ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, \ - RQSiGELULayer, SliceLayer, SoftmaxLayer, TransposeLayer -from Deeploy.Targets.Generic.Parsers import AddParser, ConcatParser, DebugParser, DequantParser, DivParser, \ - DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, GenericConv2DParser, \ - GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, IntegerDivParser, \ - ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MulParser, Pad1DParser, Pad2DParser, \ - QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, RQIntegerDivParser, \ - RQSiGELUParser, SliceParser, SoftmaxParser, TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser +from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, ConcatLayer, ConvLayer, \ + ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, \ + LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \ + ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, SoftmaxLayer, \ + TransposeLayer +from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, ConcatParser, ConvTranspose1DParser, \ + DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, \ + 
+    GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, \
+    IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, \
+    Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, \
+    ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, TransposeParser, UnsqueezeParser, \
+    iLayerNormParser, iSoftmaxParser
 from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate
 from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \
     ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \
     iGELURequantMergePass
 
 AddMapper = NodeMapper(AddParser(), BasicAddBindings)
-Conv1DMapper = NodeMapper(GenericConv1DParser(), [BasicConv1DBinding])
+Conv1DMapper = NodeMapper(GenericConv1DParser(), BasicConv1DBindings)
 Conv2DMapper = NodeMapper(GenericConv2DParser(), BasicConv2DBindings)
 ConcatMapper = NodeMapper(ConcatParser(), BasicConcatBindings)
 DebugMapper = NodeMapper(DebugParser(), BasicDebugPrintBindings)
@@ -47,6 +50,7 @@
 ITAPartialMaxMapper = NodeMapper(ITAPartialMaxParser(), [BasicITAPartialSoftmaxBinding])
 MatMulMapper = NodeMapper(MatMulParser(), BasicMatMulBindings)
 MaxPoolMapper = NodeMapper(GenericMaxPool2DParser(), BasicMaxPool2DBindings)
+MaxPool1DMapper = NodeMapper(MaxPool1DParser(), BasicMaxPool1DBindings)
 MulMapper = NodeMapper(MulParser(), BasicMulBindings)
 Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings)
 Pad2DMapper = NodeMapper(Pad2DParser(), BasicPad2DBindings)
@@ -63,7 +67,8 @@
 UnsqueezeMapper = NodeMapper(UnsqueezeParser(), BasicReshapeBindings)
 QuantMapper = NodeMapper(QuantParser(), BasicQuantBindings)
 DequantMapper = NodeMapper(DequantParser(), BasicDequantBindings)
-
+BatchNormalizationMapper = NodeMapper(BatchNormParser(), BasicBatchNormBindings)
+ConvTransposeMapper = NodeMapper(ConvTranspose1DParser(), BasicConvTransposeBindings)
 SliceMapper = NodeMapper(SliceParser(), BasicSliceBindings)
 
 # Dummy nodes are intended for development purposes only!
@@ -91,7 +96,7 @@
     'ITAPartialMax': ITAMaxLayer([ITAPartialMaxMapper]),
     'MatMul': GEMMLayer([MatMulMapper]),
     'MatMulInteger': MatMulLayer([MatMulMapper]),
-    'MaxPool': MaxPoolLayer([MaxPoolMapper]),
+    'MaxPool': MaxPoolLayer([MaxPool1DMapper, MaxPoolMapper]),
     'Mul': MulLayer([MulMapper]),
     'Pad': PadLayer([Pad1DMapper, Pad2DMapper]),
     'ReduceMean': ReduceMeanLayer([ReduceMeanMapper]),
@@ -106,7 +111,9 @@
     'Unsqueeze': ReshapeLayer([UnsqueezeMapper]),
     'Slice': SliceLayer([SliceMapper]),
     'Quant': QuantLayer([QuantMapper]),
-    'Dequant': DequantLayer([DequantMapper])
+    'Dequant': DequantLayer([DequantMapper]),
+    'BatchNormalization': BatchNormalizationLayer([BatchNormalizationMapper]),
+    'ConvTranspose': ConvTransposeLayer([ConvTransposeMapper])
     # # For example, you can use the DummpyMapper, in case you want to test
     # # deployment or optimizations with GlobalAveragePool nodes but did not yet
     # # implement the corresponding kernel
diff --git a/Deeploy/Targets/Generic/Templates/BatchNormalizationTemplate.py b/Deeploy/Targets/Generic/Templates/BatchNormalizationTemplate.py
new file mode 100644
index 000000000..5377c91ca
--- /dev/null
+++ b/Deeploy/Targets/Generic/Templates/BatchNormalizationTemplate.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from Deeploy.DeeployTypes import NodeTemplate
+
+referenceTemplate = NodeTemplate("""
+// BatchNorm (Name: ${nodeName}, Op: ${nodeOp})
+BEGIN_SINGLE_CORE
+    BatchNorm_fp32(
+        ${data_in}, ${scale}, ${bias}, ${mean}, ${variance},
+        ${data_out}, ${batch_size}, ${channel_size}, ${window_size}
+    );
+END_SINGLE_CORE
+""")
diff --git a/Deeploy/Targets/Generic/Templates/ConvTransposeTemplate.py b/Deeploy/Targets/Generic/Templates/ConvTransposeTemplate.py
new file mode 100644
index 000000000..9bf864c91
--- /dev/null
+++ b/Deeploy/Targets/Generic/Templates/ConvTransposeTemplate.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from Deeploy.DeeployTypes import NodeTemplate
+
+referenceTemplate = NodeTemplate("""
+<%
+batchOffsetIn = ch_im_in * dim_im_in_y
+batchOffsetOut = ch_im_out * dim_im_out_y
+%>
+
+// 1D Transposed Conv (Name: ${nodeName}, Op: ${nodeOp})
+BEGIN_SINGLE_CORE
+    ${data_in_type.typeName} ref_${data_out}_${data_in} = ${data_in};
+    ${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
+
+    for (uint32_t n=0; n<${batch}; ++n) {
+        ConvTranspose1d_fp32(
+            ref_${data_out}_${data_in}, ${ch_im_in}, ${dim_im_in_y},
+            ${weight}, ${ch_im_out}, ${dim_kernel_y},
+            ${stride_y},
+            ${bias}, ${has_bias},
+            ref_${data_out}_${data_out}, ${dim_im_out_y}
+        );
+
+        ref_${data_out}_${data_in} += ${batchOffsetIn};
+        ref_${data_out}_${data_out} += ${batchOffsetOut};
+    }
+END_SINGLE_CORE
+""")
diff --git a/Deeploy/Targets/Generic/Templates/FloatConvTemplate.py b/Deeploy/Targets/Generic/Templates/FloatConvTemplate.py
index f1cb7f15a..7519d33a2 100644
--- a/Deeploy/Targets/Generic/Templates/FloatConvTemplate.py
+++ b/Deeploy/Targets/Generic/Templates/FloatConvTemplate.py
@@ -29,3 +29,29 @@
     }
 END_SINGLE_CORE
 """)
+
+reference1DTemplate = NodeTemplate("""
+<%
+batchOffsetIn = ch_im_in * dim_im_in_y
+batchOffsetOut = ch_im_out * dim_im_out_y
+%>
+    // 1D FP Conv (Name: ${nodeName}, Op: ${nodeOp})
+    BEGIN_SINGLE_CORE
+    ${data_in_type.typeName} ref_${data_out}_${data_in} = ${data_in};
+    ${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
+    for (uint32_t n=0; n<${batch}; ++n) {
+        Conv1d_fp${data_in_type.referencedType.typeWidth}_fp${weight_type.referencedType.typeWidth}_fp${data_out_type.referencedType.typeWidth}(
+            ref_${data_out}_${data_in}, ${ch_im_in}, ${dim_im_in_y},
+            ${weight}, ${ch_im_out}, ${dim_kernel_y},
+            ${stride_y},
+            ${bias},
+            ${has_bias},
+            ref_${data_out}_${data_out},
+            ${dim_im_out_y}
+        );
+
+        ref_${data_out}_${data_in} += ${batchOffsetIn};
+        ref_${data_out}_${data_out} += ${batchOffsetOut};
+    }
+    END_SINGLE_CORE
+    """)
\ No newline at end of file
diff --git a/Deeploy/Targets/Generic/Templates/FloatMaxPoolTemplate.py b/Deeploy/Targets/Generic/Templates/FloatMaxPoolTemplate.py
index b5401d174..1eef5e0f4 100644
--- a/Deeploy/Targets/Generic/Templates/FloatMaxPoolTemplate.py
+++ b/Deeploy/Targets/Generic/Templates/FloatMaxPoolTemplate.py
@@ -20,3 +20,24 @@
     }
 END_SINGLE_CORE
 """)
+
+reference1DTemplate = NodeTemplate("""
+<%
+batchOffsetIn = ch_im_in * dim_im_in_y
+batchOffsetOut = ch_im_out * dim_im_out_y
+%>
+    // 1D Float MaxPool (Name: ${nodeName}, Op: ${nodeOp})
+    BEGIN_SINGLE_CORE
+    ${data_in_type.typeName} ref_${data_out}_${data_in} = ${data_in};
+    ${data_out_type.typeName} ref_${data_out}_${data_out} = ${data_out};
+    for (uint32_t n=0; n<${batch}; ++n) {
+        MaxPool1d_fp32_fp32(
+            ref_${data_out}_${data_in}, ${ch_im_in}, ${dim_im_in_y},
+            ${dim_kernel_y}, ${stride_y},
+            ref_${data_out}_${data_out}
+        );
+        ref_${data_out}_${data_in} += ${batchOffsetIn};
+        ref_${data_out}_${data_out} += ${batchOffsetOut};
+    }
+    END_SINGLE_CORE
+""")
\ No newline at end of file
diff --git a/Deeploy/Targets/Generic/Templates/FloatPadTemplate.py b/Deeploy/Targets/Generic/Templates/FloatPadTemplate.py
index c1bd56764..ad528910b 100644
--- a/Deeploy/Targets/Generic/Templates/FloatPadTemplate.py
+++ b/Deeploy/Targets/Generic/Templates/FloatPadTemplate.py
@@ -52,3 +52,42 @@
     %endif
 END_SINGLE_CORE
 """)
+
+reference1DTemplate = NodeTemplate("""
+<%
+x_offset_out = dim_im_out_ch*(pad_y)
+width = dim_im_in_ch*dim_im_in_y
+
+startPosX = x_offset_out
+
+batchOffsetOut = dim_im_out_ch * dim_im_out_y
+%>
+
+// 1D Float Pad (Name: ${nodeName}, Op: ${nodeOp})
+BEGIN_SINGLE_CORE
+    for (uint32_t i = 0; i < ${data_out_size}; i++) {
+        ${data_out}[i] = ${value};
+    }
+    uint32_t xoffset_${data_out}_${data_in};
+    uint32_t offset_in_${data_out}_${data_in} = 0;
+
+    % if channels_first:
+    // NCHW Layout
+    for(uint32_t n=0; n<${batch}; n++){
+        xoffset_${data_out}_${data_in} = n*${batchOffsetOut} + ${pad_y};
+        for (uint32_t c=0; c<${dim_im_in_ch}; ++c) {
+            memcpy(${data_out} + xoffset_${data_out}_${data_in}, ${data_in} + offset_in_${data_out}_${data_in}, ${dim_im_in_y}*sizeof(${data_out_type.referencedType.typeName}));
+            xoffset_${data_out}_${data_in} += ${dim_im_out_y};
+            offset_in_${data_out}_${data_in} += ${dim_im_in_y};
+        }
+    }
+    % else:
+    // NHWC Layout
+    for(uint32_t n=0; n<${batch}; n++){
+        xoffset_${data_out}_${data_in} = n*${batchOffsetOut} + ${startPosX};
+        memcpy(${data_out}+xoffset_${data_out}_${data_in}, ${data_in}+offset_in_${data_out}_${data_in}, ${width}*sizeof(${data_out_type.referencedType.typeName}));
+        offset_in_${data_out}_${data_in} += ${width};
+    }
+    %endif
+END_SINGLE_CORE
+""")
diff --git a/Deeploy/Targets/Generic/TypeCheckers.py b/Deeploy/Targets/Generic/TypeCheckers.py
index 8f3a12ec8..c2c8d436f 100644
--- a/Deeploy/Targets/Generic/TypeCheckers.py
+++ b/Deeploy/Targets/Generic/TypeCheckers.py
@@ -596,3 +596,17 @@ def _inferNumLevels(self, inputs: List[VariableBuffer],
     def _inferSignedness(self, inputs: List[VariableBuffer],
                          operatorRepresentation: OperatorRepresentation) -> Optional[List[bool]]:
         return [True]
+
+
+class BatchNormChecker(SignPropTypeChecker):
+
+    def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
+        super().__init__(input_types, output_types)
+
+    def _inferNumLevels(self, inputs: List[VariableBuffer],
+                        operatorRepresentation: OperatorRepresentation) -> List[int]:
+        return [2**(self.input_types[0].referencedType.typeWidth)]
+
+    def _inferSignedness(self, inputs: List[VariableBuffer],
+                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
+        return [True]
diff --git a/Deeploy/Targets/MemPool/Platform.py b/Deeploy/Targets/MemPool/Platform.py
index 4f1a98298..48599736f 100644
--- a/Deeploy/Targets/MemPool/Platform.py
+++ b/Deeploy/Targets/MemPool/Platform.py
@@ -8,7 +8,7 @@
 from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NodeMapper, NodeTemplate, \
     StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer
-from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicConv1DBinding, BasicConv2DBindings, \
+from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicConv1DBindings, BasicConv2DBindings, \
     BasicDebugPrintBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, BasicGatherBindings, \
     BasicGELUBindings, BasicLayerNormBindings, BasicMulBindings, BasicPad1DBindings, BasicPad2DBindings, \
     BasicReduceMeanBindings, BasicReduceSumBindings, BasicReshapeBindings, BasicRQIntegerDivBinding, \
@@ -37,7 +37,7 @@
 # Fallback bindings from the generic platform
 # (they support a wider range of attribute values)
-GenericConv1D_Mapper = NodeMapper(GenericConv1DParser(), [BasicConv1DBinding])
+GenericConv1D_Mapper = NodeMapper(GenericConv1DParser(), BasicConv1DBindings)
 GenericDWConv1D_Mapper = NodeMapper(GenericDWConv1DParser(), [BasicDWConv1DBinding])
 GenericConv2D_Mapper = NodeMapper(GenericConv2DParser(), BasicConv2DBindings)
 GenericDWConv2D_Mapper = NodeMapper(GenericDWConv2DParser(), BasicDWConv2DBindings)
diff --git a/DeeployTest/Tests/Autoencoder1D/inputs.npz b/DeeployTest/Tests/Autoencoder1D/inputs.npz
new file mode 100644
index 000000000..cc639dab2
Binary files /dev/null and b/DeeployTest/Tests/Autoencoder1D/inputs.npz differ
diff --git a/DeeployTest/Tests/Autoencoder1D/network.onnx b/DeeployTest/Tests/Autoencoder1D/network.onnx
new file mode 100644
index 000000000..d70e48e6d
Binary files /dev/null and b/DeeployTest/Tests/Autoencoder1D/network.onnx differ
diff --git a/DeeployTest/Tests/Autoencoder1D/outputs.npz b/DeeployTest/Tests/Autoencoder1D/outputs.npz
new file mode 100644
index 000000000..13e8f46fa
Binary files /dev/null and b/DeeployTest/Tests/Autoencoder1D/outputs.npz differ
diff --git a/TargetLibraries/Generic/inc/DeeployBasicMath.h b/TargetLibraries/Generic/inc/DeeployBasicMath.h
index f647c833e..288cb419a 100644
--- a/TargetLibraries/Generic/inc/DeeployBasicMath.h
+++ b/TargetLibraries/Generic/inc/DeeployBasicMath.h
@@ -22,6 +22,7 @@
 #include <ctype.h>
 #include <math.h>
+#include <stdbool.h>
 #include <stdint.h>
 #include <stdio.h>
@@ -31,6 +32,8 @@
 #include "types.h"
 #include "utils.h"
 
+#include "kernel/BatchNorm.h"
+#include "kernel/ConvTranspose1d_fp32.h"
 #include "kernel/Convolution.h"
 #include "kernel/DWConvolution.h"
 #include "kernel/Div.h"
@@ -40,10 +43,12 @@
 #include "kernel/Layernorm.h"
 #include "kernel/MatMul.h"
 #include "kernel/MaxPool.h"
+#include "kernel/MaxPool1d.h"
 #include "kernel/RMSNorm.h"
 #include "kernel/RQDiv.h"
 #include "kernel/RQGELU.h"
 #include "kernel/RQHardswish.h"
+#include "kernel/Relu.h"
"kernel/RequantShift.h" #include "kernel/Softmax.h" diff --git a/TargetLibraries/Generic/inc/kernel/BatchNorm.h b/TargetLibraries/Generic/inc/kernel/BatchNorm.h new file mode 100644 index 000000000..72703f5fe --- /dev/null +++ b/TargetLibraries/Generic/inc/kernel/BatchNorm.h @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef BATCHNORM_H +#define BATCHNORM_H + +#include +#include + +void BatchNorm_fp32(const float32_t *input, const float32_t *gamma, + const float32_t *beta, const float32_t *mean, + const float32_t *var, float32_t *output, int N, int C, + int L); + +#endif // BATCHNORM_H diff --git a/TargetLibraries/Generic/inc/kernel/ConvTranspose1d_fp32.h b/TargetLibraries/Generic/inc/kernel/ConvTranspose1d_fp32.h new file mode 100644 index 000000000..40ef06599 --- /dev/null +++ b/TargetLibraries/Generic/inc/kernel/ConvTranspose1d_fp32.h @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef CONV_TRANSPOSE1D_FP32_H +#define CONV_TRANSPOSE1D_FP32_H + +#include +#include + +void ConvTranspose1d_fp32(const float32_t *input, uint32_t C_in, uint32_t W_in, + const float32_t *weight, uint32_t C_out, uint32_t K, + uint32_t stride, const float32_t *bias, bool has_bias, + float32_t *output, uint32_t W_out); + +#endif // CONV_TRANSPOSE1D_FP32_H diff --git a/TargetLibraries/Generic/inc/kernel/Convolution.h b/TargetLibraries/Generic/inc/kernel/Convolution.h index f86e8dcd7..8c1d2388b 100644 --- a/TargetLibraries/Generic/inc/kernel/Convolution.h +++ b/TargetLibraries/Generic/inc/kernel/Convolution.h @@ -43,4 +43,13 @@ void Conv2d_fp32_fp32_fp32_NCHW(const float *__restrict__ pSrcA, uint32_t C, uint32_t SQ, const float *__restrict__ pSrcBias, const bool has_bias, float *__restrict__ pDstC); +void Conv1d_fp32_fp32_fp32( + const float32_t *__restrict__ pSrcA, // Input: [C_in, W_in] + uint32_t C_in, uint32_t W_in, + const float32_t *__restrict__ pSrcB, // Weights: [C_out, C_in, K] + uint32_t C_out, uint32_t K, uint32_t stride, + const float32_t *__restrict__ pSrcBias, const bool has_bias, + float32_t *__restrict__ pDstC, // Output: [C_out, W_out] + uint32_t W_out); + #endif //__DEEPLOY_BASIC_MATH_CONVOLUTION_KERNEL_HEADER_ diff --git a/TargetLibraries/Generic/inc/kernel/MaxPool1d.h b/TargetLibraries/Generic/inc/kernel/MaxPool1d.h new file mode 100644 index 000000000..26d5e8e46 --- /dev/null +++ b/TargetLibraries/Generic/inc/kernel/MaxPool1d.h @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef __DEEPLOY_BASIC_MATH_MAXPOOL1D_KERNEL_HEADER_ +#define __DEEPLOY_BASIC_MATH_MAXPOOL1D_KERNEL_HEADER_ + +#include "DeeployBasicMath.h" + +void MaxPool1d_fp32_fp32(float32_t const *__restrict__ pSrcA, uint32_t C, + uint32_t W, uint32_t K, uint32_t S, + float32_t *__restrict__ pDstC); + +#endif \ No newline at end of file diff --git a/TargetLibraries/Generic/src/BatchNorm_fp32.c b/TargetLibraries/Generic/src/BatchNorm_fp32.c new file mode 100644 index 000000000..9b30a3020 --- /dev/null +++ b/TargetLibraries/Generic/src/BatchNorm_fp32.c @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +// +// SPDX-License-Identifier: Apache-2.0 + +#include "DeeployBasicMath.h" + +void BatchNorm_fp32(const float32_t *input, const float32_t *gamma, + const float32_t *beta, const float32_t *mean, + const float32_t *var, float32_t *output, 
+                    const float32_t *var, float32_t *output, int N, int C,
+                    int L) {
+  const float32_t epsilon = 1e-5f; // ONNX default; the parsed epsilon attribute is not forwarded here
+#pragma omp parallel for
+  for (int c = 0; c < C; ++c) {
+    float32_t c_mean = mean[c];
+    float32_t c_var = var[c];
+    float32_t c_gamma = gamma[c];
+    float32_t c_beta = beta[c];
+    float32_t denom = sqrtf(c_var + epsilon);
+    for (int n = 0; n < N; ++n) {
+      for (int l = 0; l < L; ++l) {
+        int index = n * C * L + c * L + l;
+        float32_t x = input[index];
+        float32_t norm = (x - c_mean) / denom;
+        output[index] = c_gamma * norm + c_beta;
+      }
+    }
+  }
+}
diff --git a/TargetLibraries/Generic/src/ConvTranspose1d_fp32.c b/TargetLibraries/Generic/src/ConvTranspose1d_fp32.c
new file mode 100644
index 000000000..362058734
--- /dev/null
+++ b/TargetLibraries/Generic/src/ConvTranspose1d_fp32.c
@@ -0,0 +1,50 @@
+// SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "DeeployBasicMath.h"
+
+void ConvTranspose1d_fp32(const float32_t *input, uint32_t C_in, uint32_t W_in,
+                          const float32_t *weight, uint32_t C_out, uint32_t K,
+                          uint32_t stride, const float32_t *bias, bool has_bias,
+                          float32_t *output, uint32_t W_out) {
+  /*
+    input:  [C_in, W_in]
+    weight: [C_in, C_out, K]
+    output: [C_out, W_out]
+    bias:   [C_out] (optional)
+
+  */
+
+  // Output initialization
+  for (uint32_t c = 0; c < C_out; ++c) {
+    for (uint32_t w = 0; w < W_out; ++w) {
+      output[c * W_out + w] = 0.0f;
+    }
+  }
+
+  // For each output channel
+  for (uint32_t cout = 0; cout < C_out; ++cout) {
+    // For each input channel
+    for (uint32_t cin = 0; cin < C_in; ++cin) {
+      // For each input position
+      for (uint32_t w_in = 0; w_in < W_in; ++w_in) {
+        float32_t val = input[cin * W_in + w_in];
+        // Scatter: each input sample contributes to output[w_in * stride + k]
+        for (uint32_t k = 0; k < K; ++k) {
+          uint32_t w_out = w_in * stride + k;
+          if (w_out < W_out) {
+            // weight indexing: weight[cin, cout, k]
+            float32_t wgt = weight[cin * (C_out * K) + cout * K + k];
+            output[cout * W_out + w_out] += val * wgt;
+          }
+        }
+      }
+    }
+    if (has_bias) {
+      for (uint32_t w = 0; w < W_out; ++w) {
+        output[cout * W_out + w] += bias[cout];
+      }
+    }
+  }
+}
diff --git a/TargetLibraries/Generic/src/Convolution_fp32.c b/TargetLibraries/Generic/src/Convolution_fp32.c
index 172749bce..e073e1812 100644
--- a/TargetLibraries/Generic/src/Convolution_fp32.c
+++ b/TargetLibraries/Generic/src/Convolution_fp32.c
@@ -66,3 +66,32 @@ void Conv2d_fp32_fp32_fp32_NCHW(const float32_t *__restrict__ pSrcA, uint32_t C,
     }
   }
 }
+
+void Conv1d_fp32_fp32_fp32(
+    const float32_t *__restrict__ pSrcA, // Input: [C_in, W_in]
+    uint32_t C_in, uint32_t W_in,
+    const float32_t *__restrict__ pSrcB, // Weights: [C_out, C_in, K]
+    uint32_t C_out, uint32_t K, uint32_t stride,
+    const float32_t *__restrict__ pSrcBias, const bool has_bias,
+    float32_t *__restrict__ pDstC, // Output: [C_out, W_out]
+    uint32_t W_out) {
+  uint32_t c_out, c_in, w_out, k, w_in;
+  for (c_out = 0; c_out < C_out; ++c_out) {
+    for (w_out = 0; w_out < W_out; ++w_out) {
+      float32_t sum = 0.0f;
+      for (c_in = 0; c_in < C_in; ++c_in) {
+        for (k = 0; k < K; ++k) {
+          w_in = w_out * stride + k;
+          if (w_in < W_in) {
+            sum += pSrcA[c_in * W_in + w_in] *
+                   pSrcB[c_out * C_in * K + c_in * K + k];
+          }
+        }
+      }
+      if (has_bias) {
+        sum += pSrcBias[c_out];
+      }
+      pDstC[c_out * W_out + w_out] = sum;
+    }
+  }
+}
\ No newline at end of file
diff --git a/TargetLibraries/Generic/src/MaxPool1D_fp32.c b/TargetLibraries/Generic/src/MaxPool1D_fp32.c
new file mode 100644
index 000000000..a8686503b
--- /dev/null
+++ b/TargetLibraries/Generic/src/MaxPool1D_fp32.c
@@ -0,0 +1,27 @@
+// SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "DeeployBasicMath.h"
+#include <math.h>
+
+void MaxPool1d_fp32_fp32(float32_t const *__restrict__ pSrcA, uint32_t C,
+                         uint32_t W, uint32_t K, uint32_t S,
+                         float32_t *__restrict__ pDstC) {
+  uint32_t W_out = (W - K) / S + 1;
+  for (uint32_t c = 0; c < C; ++c) {
+    for (uint32_t w_out = 0; w_out < W_out; ++w_out) {
+      float32_t max = -INFINITY;
+      for (uint32_t k = 0; k < K; ++k) {
+        uint32_t w_in = w_out * S + k;
+        if (w_in >= W)
+          continue;
+        float32_t tmp = pSrcA[c * W + w_in];
+        if (tmp > max) {
+          max = tmp;
+        }
+      }
+      pDstC[c * W_out + w_out] = max;
+    }
+  }
+}
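
Editorial note (not part of the patch): the new ConvTranspose1d_fp32 kernel scatters each input sample into output[w_in * stride + k], so for pads = [0, 0] and no output_padding it fills exactly W_out = (W_in - 1) * stride + K samples, matching ONNX shape inference. A minimal standalone sketch with made-up sizes and values:

// Hypothetical driver for ConvTranspose1d_fp32; sizes and data are illustrative.
#include <stdio.h>
#include "DeeployBasicMath.h"

int main(void) {
  // C_in = 1, W_in = 3, C_out = 1, K = 2, stride = 2
  // => W_out = (W_in - 1) * stride + K = 6
  const float32_t input[3] = {1.0f, 2.0f, 3.0f}; // [C_in, W_in]
  const float32_t weight[2] = {1.0f, 0.5f};      // [C_in, C_out, K]
  float32_t output[6];                           // [C_out, W_out]

  ConvTranspose1d_fp32(input, 1, 3, weight, 1, 2, 2, NULL, false, output, 6);

  // Each input[w_in] adds input[w_in] * weight[k] to output[w_in * stride + k],
  // giving {1.0, 0.5, 2.0, 1.0, 3.0, 1.5}.
  for (int w = 0; w < 6; ++w)
    printf("%f ", output[w]);
  printf("\n");
  return 0;
}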
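The forward kernel Conv1d_fp32_fp32_fp32 is the gathering counterpart: it reads input[w_out * stride + k] with no implicit padding, which is why the generic platform pairs it with ExtractPaddingFromConvPass and the new float Pad1D binding. A numeric sketch under the same illustrative assumptions:

// Hypothetical check for Conv1d_fp32_fp32_fp32; sizes and data are illustrative.
#include <stdio.h>
#include "DeeployBasicMath.h"

int main(void) {
  // C_in = 1, W_in = 4, C_out = 1, K = 2, stride = 1 => W_out = 3
  const float32_t in[4] = {1.0f, 2.0f, 3.0f, 4.0f}; // [C_in, W_in]
  const float32_t w[2] = {1.0f, -1.0f};             // [C_out, C_in, K]
  const float32_t b[1] = {0.5f};                    // [C_out]
  float32_t out[3];

  Conv1d_fp32_fp32_fp32(in, 1, 4, w, 1, 2, 1, b, true, out, 3);

  // out[i] = in[i] - in[i + 1] + 0.5 => {-0.5, -0.5, -0.5}
  printf("%f %f %f\n", out[0], out[1], out[2]);
  return 0;
}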
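BatchNorm_fp32 applies the inference-mode normalization y = gamma * (x - mean) / sqrt(var + epsilon) + beta per channel; note that the kernel pins epsilon to the ONNX default of 1e-5, so the epsilon attribute stored by BatchNormParser is currently informational only. A single-channel sketch with made-up values:

// Hypothetical check for BatchNorm_fp32; sizes and data are illustrative.
#include <stdio.h>
#include "DeeployBasicMath.h"

int main(void) {
  // N = 1, C = 1, L = 4
  const float32_t x[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  const float32_t gamma[1] = {2.0f}, beta[1] = {1.0f};
  const float32_t mean[1] = {1.5f}, var[1] = {1.0f};
  float32_t y[4];

  BatchNorm_fp32(x, gamma, beta, mean, var, y, 1, 1, 4);

  // y[l] = 2 * (x[l] - 1.5) / sqrtf(1.0f + 1e-5f) + 1 ~= {-2.0, 0.0, 2.0, 4.0}
  for (int l = 0; l < 4; ++l)
    printf("%f ", y[l]);
  printf("\n");
  return 0;
}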
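MaxPool1d_fp32_fp32 likewise assumes padding has already been materialized upstream (see the Pad1D float binding) and uses the floor-mode output width W_out = (W - K) / S + 1. Illustrative values:

// Hypothetical check for MaxPool1d_fp32_fp32; sizes and data are illustrative.
#include <stdio.h>
#include "DeeployBasicMath.h"

int main(void) {
  // C = 1, W = 5, K = 2, S = 2 => W_out = (5 - 2) / 2 + 1 = 2
  const float32_t in[5] = {1.0f, 3.0f, 2.0f, 5.0f, 4.0f};
  float32_t out[2];

  MaxPool1d_fp32_fp32(in, 1, 5, 2, 2, out);

  // Window {1, 3} -> 3, window {2, 5} -> 5
  printf("%f %f\n", out[0], out[1]);
  return 0;
}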