Commit dc111be
Support for 1D Autoencoder (#98)
This PR adds the required kernels, bindings, and templates to support 1D autoencoders.

## Added

- Support for 1D Autoencoder operations, including newly implemented BatchNorm, 1D ConvTranspose, and 1D MaxPool layers.
- An `Autoencoder1D` test in the CI.
1 parent 15c4a23 commit dc111be

25 files changed: +733 −36 lines
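For orientation, the kind of network this commit enables can be sketched as a small 1D convolutional autoencoder. The snippet below is illustrative only (the actual Autoencoder1D test network is not shown on this page); it uses PyTorch purely to spell out the operator mix, 1D Conv, BatchNorm, 1D MaxPool, and 1D ConvTranspose, that the new kernels and bindings cover:

# Illustrative sketch, not the actual Autoencoder1D test network: a minimal
# 1D autoencoder whose ONNX export exercises the operators this PR adds
# (Conv, BatchNormalization, MaxPool, ConvTranspose).
import torch
import torch.nn as nn


class Autoencoder1D(nn.Module):

    def __init__(self, channels: int = 8):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, channels, kernel_size=3, padding=1),  # 1D convolution
            nn.BatchNorm1d(channels),  # BatchNorm kernel
            nn.ReLU(),
            nn.MaxPool1d(2),  # 1D MaxPool kernel
        )
        self.decoder = nn.ConvTranspose1d(channels, 1, kernel_size=2, stride=2)  # 1D ConvTranspose kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


# Export to ONNX; Deeploy's generic bindings map the resulting nodes.
torch.onnx.export(Autoencoder1D().eval(), torch.randn(1, 1, 64), "autoencoder1d.onnx")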

.github/workflows/ci-platform-generic.yml

Lines changed: 1 addition & 0 deletions
@@ -96,3 +96,4 @@ jobs:
   CCT/CCT_1_16_16_8
   CCT/CCT_2_32_32_128_Opset20
   testFloatDemoTinyViT
+  Autoencoder1D

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
 ## Unreleased (Planned Release Target: v0.2.1)
 
 ### List of Pull Requests
+- Support for 1D Autoencoder [#98](https://github.com/pulp-platform/Deeploy/pull/98)
 - Refactor Logging for Improved Debugging [#115](https://github.com/pulp-platform/Deeploy/pull/115)
 - Add reuse-tool as an SPDX license header linter [#113](https://github.com/pulp-platform/Deeploy/pull/113)
 - Bug fixes, API Cleanup and Reduce Compiler Warning on PULP [#112](https://github.com/pulp-platform/Deeploy/pull/112)
@@ -158,6 +159,13 @@ This release containing major architectural changes, new platform support, enhan
 
 
 ### Added
+- BatchNorm kernel
+- ConvTranspose kernel
+- MaxPool1D kernel
+- Template for 1D Convolution
+- Support for float32 data type in the previous kernels
+- Float binding for Pad1D kernel
+- Test for Autoencoder1D in the CI pipeline
 - ChimeraDeployer, currently mainly a placeholder
 - Allocate templates for Chimera
 - ChimeraPlatform, using appropriate allocation templates and using the generic Parser + Binding for the Add node
@@ -291,6 +299,8 @@ This release containing major architectural changes, new platform support, enhan
 - `dev-requirements.txt` tracking the dependencies of the build system, linting, documentation, and QOL.
 
 ### Changed
+- FloatConvTemplate file
+- Platform.py file
 - Bump the CMake version to 3.24 as required for the chimera-sdk
 - Bump GVSoC's version and add chimera simulation target
 - Rename the generic source util to utils to avoid name collision with chimera-sdk

Deeploy/Targets/Generic/Bindings.py

Lines changed: 58 additions & 15 deletions
@@ -11,19 +11,20 @@
     int8_t, int32_t, uint8_t
 from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
 from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
-from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, ConvTemplate, DebugPrintTemplate, \
-    DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, FloatConvTemplate, FloatDivTemplate, \
-    FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \
-    FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, FloatReduceMeanTemplate, FloatReluTemplate, \
-    FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, \
-    MatMulTemplate, MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
-    RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
-    iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
-from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DebugPrintChecker, \
-    DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, LayerNormChecker, \
-    MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, ReduceSumChecker, \
-    ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, SoftmaxChecker, \
-    TransposeChecker
+from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \
+    ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
+    FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \
+    FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \
+    FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \
+    IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \
+    PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \
+    RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \
+    iRMSNormTemplate, iSoftmaxTemplate
+from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
+    DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
+    LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
+    ReduceSumChecker, ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, \
+    SoftmaxChecker, TransposeChecker
 
 BasicTransformer = CodeTransformation([ArgumentStructGeneration(), MemoryManagementGeneration(), FutureGeneration()])
 
@@ -53,8 +54,14 @@
     FloatAddTemplate.referenceTemplate, BasicTransformer)
 ]
 
-BasicConv1DBinding = NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
-                                 ConvTemplate.reference1DTemplate, BasicTransformer)
+BasicConv1DBindings = [
+    NodeBinding(ConvChecker(
+        [PointerClass(type), PointerClass(type), PointerClass(type)], [PointerClass(type)]),
+                FloatConvTemplate.reference1DTemplate, BasicTransformer) for type in FloatDataTypes
+] + [
+    NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
+                ConvTemplate.reference1DTemplate, BasicTransformer)
+]
 
 BasicDWConv1DBinding = NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
                                    DWConvTemplate.reference1DTemplate, BasicTransformer)
@@ -147,6 +154,11 @@
     FloatMatMulTemplate.referenceTemplate, BasicTransformer)
 ]
 
+BasicMaxPool1DBindings = [
+    NodeBinding(MaxPoolChecker([PointerClass(type)], [PointerClass(type)]), FloatMaxPoolTemplate.reference1DTemplate,
+                BasicTransformer) for type in FloatDataTypes
+]
+
 BasicMaxPool2DBindings = [
     NodeBinding(MaxPoolChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), MaxPoolTemplate.referenceTemplate,
                 BasicTransformer)
@@ -167,7 +179,11 @@
 BasicPad1DBindings = [
     NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), PadTemplate.reference1DTemplate,
                 BasicTransformer) for type in SignedIntegerDataTypes
+] + [
+    NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), FloatPadTemplate.reference1DTemplate,
+                BasicTransformer) for type in FloatDataTypes
 ]
+
 BasicPad2DBindings = [
     NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), PadTemplate.reference2DTemplate,
                 BasicTransformer) for type in SignedIntegerDataTypes
@@ -266,3 +282,30 @@
     NodeBinding(DequantChecker([PointerClass(int32_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate,
                 BasicTransformer),
 ]
+
+BasicBatchNormBindings = [
+    NodeBinding(
+        BatchNormChecker(
+            [PointerClass(type),
+             PointerClass(type),
+             PointerClass(type),
+             PointerClass(type),
+             PointerClass(type)], [PointerClass(type)]), BatchNormalizationTemplate.referenceTemplate, BasicTransformer)
+    for type in FloatDataTypes
+]
+
+BasicConvTransposeBindings = [
+    NodeBinding(
+        ConvChecker(
+            [PointerClass(type), PointerClass(type), PointerClass(type)],  # input, weight, bias
+            [PointerClass(type)]),
+        ConvTransposeTemplate.referenceTemplate,
+        BasicTransformer) for type in FloatDataTypes
+] + [
+    NodeBinding(
+        ConvChecker(
+            [PointerClass(type), PointerClass(type)],  # input, weight
+            [PointerClass(type)]),
+        ConvTransposeTemplate.referenceTemplate,
+        BasicTransformer) for type in FloatDataTypes
+]
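A note on the comprehension pattern used above: each list comprehension instantiates one NodeBinding per element of FloatDataTypes, pairing a type checker with a reference template. As a rough manual expansion, assuming FloatDataTypes contains only float32_t and that the five checker inputs follow the ONNX BatchNormalization input order (an assumption, not stated in this diff), BasicBatchNormBindings would reduce to:

# Hypothetical expansion of BasicBatchNormBindings, assuming
# FloatDataTypes == [float32_t]; the input order is assumed to follow
# ONNX BatchNormalization (X, scale, bias, mean, variance).
BasicBatchNormBindings = [
    NodeBinding(
        BatchNormChecker(
            [
                PointerClass(float32_t),  # X: input data
                PointerClass(float32_t),  # scale (gamma)
                PointerClass(float32_t),  # bias (beta)
                PointerClass(float32_t),  # running mean
                PointerClass(float32_t),  # running variance
            ],
            [PointerClass(float32_t)]),  # Y: output
        BatchNormalizationTemplate.referenceTemplate, BasicTransformer)
]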

Deeploy/Targets/Generic/Layers.py

Lines changed: 61 additions & 0 deletions
@@ -618,3 +618,64 @@ class DequantLayer(ONNXLayer):
 
     def __init__(self, maps: List[NodeMapper]):
         super().__init__(maps)
+
+
+class BatchNormalizationLayer(ONNXLayer):
+
+    def __init__(self, maps: List[NodeMapper]):
+        super().__init__(maps)
+
+    def computeOps(self):
+        # 5 operations per element: sub, mul, add, sqrt, div
+        B = self.mapper.parser.operatorRepresentation['batch_size']
+        C = self.mapper.parser.operatorRepresentation['channel_size']
+        W = self.mapper.parser.operatorRepresentation['window_size']
+        return B * C * W * 5
+
+
+class ConvTransposeLayer(ONNXLayer):
+
+    def __init__(self, maps: List[NodeMapper]):
+        super().__init__(maps)
+
+    def computeShapes(self, inputShapes: Shape, outputShapes: Shape, operatorRepresentation,
+                      channels_first) -> Tuple[Shape, Shape]:
+        """
+        Infers output shapes for ConvTranspose using only static info.
+        - inputShapes[0]: input tensor shape (e.g., [N, C_in, W] for 1D, [N, C_in, H, W] for 2D)
+        - inputShapes[1]: weight tensor shape (e.g., [C_in, C_out // group, kW] for 1D)
+        - outputShapes[0]: output tensor shape (to be updated)
+        """
+        newInputShapes = list(inputShapes)
+        newOutputShapes = list(outputShapes)
+        group = operatorRepresentation.get('group', 1)
+        weight_shape = inputShapes[1]
+
+        if newOutputShapes and len(newOutputShapes[0]) >= 2:
+            # For 1D: weight_shape = [C_in, C_out // group, kW]
+            # For 2D: weight_shape = [C_in, C_out // group, kH, kW]
+            ch_out = weight_shape[1] * group
+            if channels_first:
+                newOutputShapes[0][1] = ch_out
+            else:
+                newOutputShapes[0][-1] = ch_out
+
+        return newInputShapes, newOutputShapes
+
+    def computeOps(self):
+        opRep = self.mapper.parser.operatorRepresentation
+
+        groups = opRep.get('group', 1)
+        kernel_shape = np.prod(opRep['kernel_shape'])  # e.g., [3, 3] -> 9
+        ch_in = opRep['ch_im_in']
+        ch_out = opRep['ch_im_out']
+
+        opsPerPx = int(kernel_shape * ch_in * ch_out / groups) * 2
+
+        # ConvTranspose upscales the spatial dims, so the pixel count comes from the output
+        if 'dim_im_out_y' in opRep:
+            numPx = opRep['dim_im_out_x'] * opRep['dim_im_out_y']
+        else:
+            numPx = opRep['dim_im_out_x']
+
+        return numPx * opsPerPx