Skip to content

Commit 2b46d2d

Browse files
committed
AttentionTraining Support
1 parent 7f99f2c commit 2b46d2d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+365
-37
lines changed

Deeploy/Targets/Generic/Layers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ class GELUGradLayer(ONNXLayer):
8484

8585
def __init__(self, maps: List[NodeMapper]):
8686
super().__init__(maps)
87+
88+
def computeOps(self):
89+
size = self.mapper.parser.operatorRepresentation['size']
90+
ops_per_element = 9
91+
gelu_grad_ops = size * ops_per_element
92+
return gelu_grad_ops
93+
8794

8895
class iHardswishLayer(ONNXLayer):
8996

@@ -490,6 +497,11 @@ class SGDLayer(ONNXLayer):
490497
def __init__(self, maps: List[NodeMapper]):
491498
super().__init__(maps)
492499

500+
def computeOps(self):
501+
502+
size = self.mapper.parser.operatorRepresentation['size']
503+
return size*2
504+
493505

494506
class LinearAttentionLayer(ONNXLayer):
495507

Deeploy/Targets/Generic/Parsers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ def __init__(self):
991991

992992
def parseNode(self, node: gs.Node) -> (bool):
993993

994-
ret = all([len(node.inputs) == 1, len(node.outputs) == 1])
994+
ret = all([len(node.inputs) >= 1, len(node.outputs) == 1])
995995

996996
return ret
997997

Deeploy/Targets/PULPOpen/Bindings.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@
180180
] + [
181181
NodeBinding(AddChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
182182
FloatAddTemplate.referenceTemplate, ForkTransformer)
183-
]
183+
] + [
184+
NodeBinding(AddChecker([PointerClass(float32_t), PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
185+
FloatAddTemplate.referenceTemplate, ForkTransformer)
186+
]
184187

185188
PULPRQSConv2DBindings = [
186189
NodeBinding(

Deeploy/Targets/PULPOpen/Templates/SGDTemplate.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,45 @@
2626
from Deeploy.DeeployTypes import NodeTemplate
2727

2828
referenceTemplate = NodeTemplate("""
29-
// SGD Weight Update (Name: ${nodeName}, Op: ${nodeOp})
30-
BEGIN_SINGLE_CORE
31-
${weight_type.typeName} ref_${weight} = ${weight};
32-
${grad_type.typeName} ref_${grad} = ${grad};
33-
${weight_type.typeName} ref_${weight_updated} = ${weight_updated};
29+
// SGD Weight Update with Separated Multiplication and Subtraction Unrolling
30+
// (Name: ${nodeName}, Op: ${nodeOp})
31+
int8_t ${nodeName}_core_id = pi_core_id();
32+
int8_t ${nodeName}_log2Core = log2(NUM_CORES);
33+
int16_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0);
34+
int16_t ${nodeName}_chunk_start = MIN(${nodeName}_chunk*${nodeName}_core_id, ${size});
35+
int16_t ${nodeName}_chunk_stop = MIN(${nodeName}_chunk_start + ${nodeName}_chunk, ${size});
36+
37+
${weight_type.typeName} ref_${weight} = ${weight};
38+
${grad_type.typeName} ref_${grad} = ${grad};
39+
${weight_type.typeName} ref_${weight_updated} = ${weight_updated};
40+
41+
float32_t learning_rate = ${lr};
42+
43+
// Temporary buffer for multiplication results
44+
float32_t temp_mul[6];
45+
46+
uint32_t i = ${nodeName}_chunk_start;
47+
for (; i+5 < ${nodeName}_chunk_stop; i+=6) {
48+
// Unrolled multiplication operations
49+
temp_mul[0] = learning_rate * ref_${grad}[i];
50+
temp_mul[1] = learning_rate * ref_${grad}[i+1];
51+
temp_mul[2] = learning_rate * ref_${grad}[i+2];
52+
temp_mul[3] = learning_rate * ref_${grad}[i+3];
53+
temp_mul[4] = learning_rate * ref_${grad}[i+4];
54+
temp_mul[5] = learning_rate * ref_${grad}[i+5];
3455
35-
float32_t learning_rate = ${lr};
56+
// Unrolled subtraction operations
57+
ref_${weight_updated}[i] = ref_${weight}[i] - temp_mul[0];
58+
ref_${weight_updated}[i+1] = ref_${weight}[i+1] - temp_mul[1];
59+
ref_${weight_updated}[i+2] = ref_${weight}[i+2] - temp_mul[2];
60+
ref_${weight_updated}[i+3] = ref_${weight}[i+3] - temp_mul[3];
61+
ref_${weight_updated}[i+4] = ref_${weight}[i+4] - temp_mul[4];
62+
ref_${weight_updated}[i+5] = ref_${weight}[i+5] - temp_mul[5];
63+
}
3664
37-
for (uint32_t i=0; i<${size}; ++i) {
38-
ref_${weight_updated}[i] = ref_${weight}[i] - learning_rate * ref_${grad}[i];
39-
}
40-
END_SINGLE_CORE
41-
""")
65+
// Handle remaining elements
66+
for (; i < ${nodeName}_chunk_stop; i++) {
67+
float32_t temp_grad = learning_rate * ref_${grad}[i];
68+
ref_${weight_updated}[i] = ref_${weight}[i] - temp_grad;
69+
}
70+
""")
Binary file not shown.
Binary file not shown.
-1.09 MB
Binary file not shown.
-1.12 MB
Binary file not shown.
-12.5 KB
Binary file not shown.
-1.16 MB
Binary file not shown.

0 commit comments

Comments
 (0)