|
11 | 11 | int8_t, int32_t, uint8_t |
12 | 12 | from Deeploy.DeeployTypes import CodeTransformation, NodeBinding |
13 | 13 | from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration |
14 | | -from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, ConvTemplate, DebugPrintTemplate, \ |
15 | | - DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, FloatConvTemplate, FloatDivTemplate, \ |
16 | | - FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \ |
17 | | - FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, FloatReduceMeanTemplate, FloatReluTemplate, \ |
18 | | - FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, \ |
19 | | - MatMulTemplate, MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \ |
20 | | - RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \ |
21 | | - iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate |
22 | | -from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DebugPrintChecker, \ |
23 | | - DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, LayerNormChecker, \ |
24 | | - MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, ReduceSumChecker, \ |
25 | | - ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, SoftmaxChecker, \ |
26 | | - TransposeChecker |
| 14 | +from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \ |
| 15 | + ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \ |
| 16 | + FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \ |
| 17 | + FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \ |
| 18 | + FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \ |
| 19 | + IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \ |
| 20 | + PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \ |
| 21 | + RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \ |
| 22 | + iRMSNormTemplate, iSoftmaxTemplate |
| 23 | +from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \ |
| 24 | + DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \ |
| 25 | + LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \ |
| 26 | + ReduceSumChecker, ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, \ |
| 27 | + SoftmaxChecker, TransposeChecker |
27 | 28 |
|
28 | 29 | BasicTransformer = CodeTransformation([ArgumentStructGeneration(), MemoryManagementGeneration(), FutureGeneration()]) |
29 | 30 |
|
|
53 | 54 | FloatAddTemplate.referenceTemplate, BasicTransformer) |
54 | 55 | ] |
55 | 56 |
|
56 | | -BasicConv1DBinding = NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]), |
57 | | - ConvTemplate.reference1DTemplate, BasicTransformer) |
| 57 | +BasicConv1DBindings = [ |
| 58 | + NodeBinding(ConvChecker( |
| 59 | + [PointerClass(type), PointerClass(type), PointerClass(type)], [PointerClass(type)]), |
| 60 | + FloatConvTemplate.reference1DTemplate, BasicTransformer) for type in FloatDataTypes |
| 61 | +] + [ |
| 62 | + NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]), |
| 63 | + ConvTemplate.reference1DTemplate, BasicTransformer) |
| 64 | +] |
58 | 65 |
|
59 | 66 | BasicDWConv1DBinding = NodeBinding(ConvChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]), |
60 | 67 | DWConvTemplate.reference1DTemplate, BasicTransformer) |
|
147 | 154 | FloatMatMulTemplate.referenceTemplate, BasicTransformer) |
148 | 155 | ] |
149 | 156 |
|
| 157 | +BasicMaxPool1DBindings = [ |
| 158 | + NodeBinding(MaxPoolChecker([PointerClass(type)], [PointerClass(type)]), FloatMaxPoolTemplate.reference1DTemplate, |
| 159 | + BasicTransformer) for type in FloatDataTypes |
| 160 | +] |
| 161 | + |
150 | 162 | BasicMaxPool2DBindings = [ |
151 | 163 | NodeBinding(MaxPoolChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), MaxPoolTemplate.referenceTemplate, |
152 | 164 | BasicTransformer) |
|
167 | 179 | BasicPad1DBindings = [ |
168 | 180 | NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), PadTemplate.reference1DTemplate, |
169 | 181 | BasicTransformer) for type in SignedIntegerDataTypes |
| 182 | +] + [ |
| 183 | + NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), FloatPadTemplate.reference1DTemplate, |
| 184 | + BasicTransformer) for type in FloatDataTypes |
170 | 185 | ] |
| 186 | + |
171 | 187 | BasicPad2DBindings = [ |
172 | 188 | NodeBinding(PadChecker([PointerClass(type)], [PointerClass(type)]), PadTemplate.reference2DTemplate, |
173 | 189 | BasicTransformer) for type in SignedIntegerDataTypes |
|
266 | 282 | NodeBinding(DequantChecker([PointerClass(int32_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, |
267 | 283 | BasicTransformer), |
268 | 284 | ] |
| 285 | + |
| 286 | +BasicBatchNormBindings = [ |
| 287 | + NodeBinding( |
| 288 | + BatchNormChecker( |
| 289 | + [PointerClass(type), |
| 290 | + PointerClass(type), |
| 291 | + PointerClass(type), |
| 292 | + PointerClass(type), |
| 293 | + PointerClass(type)], [PointerClass(type)]), BatchNormalizationTemplate.referenceTemplate, BasicTransformer) |
| 294 | + for type in FloatDataTypes |
| 295 | +] |
| 296 | + |
| 297 | +BasicConvTransposeBindings = [ |
| 298 | + NodeBinding( |
| 299 | + ConvChecker( |
| 300 | + [PointerClass(type), PointerClass(type), PointerClass(type)], # input, weight, bias |
| 301 | + [PointerClass(type)]), |
| 302 | + ConvTransposeTemplate.referenceTemplate, |
| 303 | + BasicTransformer) for type in FloatDataTypes |
| 304 | +] + [ |
| 305 | + NodeBinding( |
| 306 | + ConvChecker( |
| 307 | + [PointerClass(type), PointerClass(type)], # input, weight |
| 308 | + [PointerClass(type)]), |
| 309 | + ConvTransposeTemplate.referenceTemplate, |
| 310 | + BasicTransformer) for type in FloatDataTypes |
| 311 | +] |
0 commit comments