@@ -42,12 +42,11 @@ def __init__(self, templateStr):
4242 @staticmethod
4343 def computeTransientBuffersSize (
4444 ctxt : NetworkContext ,
45- operatorRepresentation : OperatorRepresentation ) -> List [Tuple [str , Union [ int , IntVar ] ]]:
45+ operatorRepresentation : OperatorRepresentation ) -> List [Tuple [str , str ]]:
4646
4747 # Memory allocation for the im2col buffer can be dynamic, based on the number of cores.
4848 # WARNING: This works because value is only used as string, in the allocate template.
49- # TODO: This should work as NUM_CORES * P * Q, but it raises an error if double the memory is not allocated.
50- im2col_dim = "2 * NUM_CORES * " + str (
49+ im2col_dim = "NUM_CORES * " + str (
5150 operatorRepresentation ['dim_kernel_x' ] * operatorRepresentation ['dim_kernel_y' ])
5251 im2col_name = operatorRepresentation ['nodeName' ] + "_buffer"
5352 return [(im2col_name , im2col_dim )]
@@ -58,6 +57,9 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
5857 ctxt , operatorRepresentation )[0 ]
5958 ctxt .hoistTransientBuffer (im2col_name , im2col_dim )
6059
60+ # Manually set the type of the im2col buffer to match the input type, since it defaults to void for transient buffers
61+ ctxt .lookup (im2col_name )._type .referencedType = ctxt .lookup (operatorRepresentation ['data_in' ])._type .referencedType
62+
6163 operatorRepresentation ['ctxtBuffer' ] = im2col_name
6264 operatorRepresentation ['ctxtBufferSize' ] = im2col_dim
6365 return ctxt , operatorRepresentation , [im2col_name ]
@@ -107,11 +109,11 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
107109 ${stride_y},
108110 ${bias}, ${has_bias},
109111 ref_${data_out}_${data_out},
110- ${padding_y_top},
111- ${padding_y_bottom},
112- ${padding_x_left},
113- ${padding_x_right},
114- ${ctxtBuffer}
112+ ${padding_y_top},
113+ ${padding_y_bottom},
114+ ${padding_x_left},
115+ ${padding_x_right},
116+ ${ctxtBuffer}
115117 );
116118
117119 ref_${data_out}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
@@ -139,11 +141,11 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
139141 ${stride_y},
140142 ${bias}, ${has_bias},
141143 ref_${data_out}_${data_out},
142- ${padding_y_top},
143- ${padding_y_bottom},
144- ${padding_x_left},
145- ${padding_x_right},
146- ${ctxtBuffer}
144+ ${padding_y_top},
145+ ${padding_y_bottom},
146+ ${padding_x_left},
147+ ${padding_x_right},
148+ ${ctxtBuffer}
147149 );
148150
149151 ref_${data_out}_${data_in} += ${ch_im_in} * ${dim_im_in_x} * ${dim_im_in_y};
0 commit comments