@@ -61,24 +61,25 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
     const int64_t info_B_num_bits,
     const int64_t info_B_mask_int64,
     const Tensor& vbe_B_offsets_rank_per_feature,
-    const int64_t max_B
+    const c10::SymInt max_B
     {%- else %}
     const Tensor& feature_requires_grad
     {%- endif %}
 ) {
   {%- if vbe %}
   Tensor offsets_;
+  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
   AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_grad_indices", [&]() {
     offsets_ = reshape_vbe_offsets<index_t>(
         offsets,
         vbe_B_offsets_rank_per_feature,
-        max_B,
+        max_B_int,
         D_offsets.numel() - 1
     );
   });
   const auto grad_output_ = reshape_vbe_output(
       grad_output,
-      max_B,
+      max_B_int,
       vbe_B_offsets_rank_per_feature,
       D_offsets
   );
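This hunk establishes the pattern the rest of the diff repeats: the wrapper's public signature takes `c10::SymInt` so the op can be traced with symbolic shapes, and the value is materialized exactly once via `guard_int(__FILE__, __LINE__)` before reaching the eager-only reshape helpers. A minimal self-contained sketch of that pattern, using a hypothetical `reshape_rows` helper in place of the real `reshape_vbe_offsets`/`reshape_vbe_output`:

```cpp
#include <ATen/ATen.h>
#include <c10/core/SymInt.h>

// Hypothetical eager-only helper standing in for reshape_vbe_*: it
// needs a concrete batch size, not a symbolic one.
static at::Tensor reshape_rows(const at::Tensor& t, int64_t max_B) {
  return t.reshape({max_B, -1});
}

at::Tensor wrapper_sketch(const at::Tensor& grad_output, c10::SymInt max_B) {
  // guard_int() installs a guard at this source location when max_B is
  // symbolic (it is free for a plain integer) and returns the backing
  // int64_t that the helper below requires.
  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
  return reshape_rows(grad_output, max_B_int);
}
```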
@@ -128,14 +129,15 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const int64_t info_B_mask_int64,
     const Tensor& vbe_B_offsets_rank_per_feature,
     const Tensor& vbe_output_offsets_feature_rank,
-    const int64_t max_B,
+    const c10::SymInt max_B,
     {%- endif %}
     const bool /*is_experimental = false*/,
     const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
   Tensor offsets_;
   {%- if vbe %}
+  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
   AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_forward", [&]() {
-    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
   });
   {%- endif %}
   static auto op =
@@ -206,7 +208,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
     const Tensor& weights_placements,
     const Tensor& weights_offsets,
     const Tensor& D_offsets,
-    const int64_t max_D,
+    const c10::SymInt max_D,
     const bool mixed_D,
     const Tensor& hash_size_cumsum,
     const int64_t total_hash_size_bits,
@@ -225,7 +227,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
     const Tensor& vbe_row_output_offsets,
     const Tensor& vbe_b_t_map,
     const Tensor& vbe_B_offsets_rank_per_feature,
-    const int64_t max_B,
+    const c10::SymInt max_B,
     {%- endif %}
     const bool /*use_uniq_cache_locations*/,
     const bool /*use_homogeneous_placements*/,
@@ -235,11 +237,12 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
     {%- endif %})
 {
   {%- if vbe %}
+  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
   Tensor offsets_;
   AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_backward", [&]() {
-    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+    offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
   });
-  const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
+  const auto grad_output_ = reshape_vbe_output(grad_output, max_B_int, vbe_B_offsets_rank_per_feature, D_offsets);
   {%- endif %}
   {%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(
       optimizer
@@ -276,7 +279,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
         weights_placements,
         weights_offsets,
         D_offsets,
-        max_D,
+        max_D.guard_int(__FILE__, __LINE__),
         hash_size_cumsum,
         total_hash_size_bits,
         indices,
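Note the asymmetry with the earlier hunks: `max_B` is guarded once and reused through `max_B_int`, while `max_D` here is guarded inline at its single point of use. A minimal sketch of the inline style, with a hypothetical `consume` callee standing in for the underlying eager backward op:

```cpp
#include <c10/core/SymInt.h>

// Stand-in for the real backward op, which takes a concrete int64_t.
static int64_t consume(int64_t max_D) {
  return max_D * 2;
}

int64_t call_site_sketch(c10::SymInt max_D) {
  // Guarding inline avoids a named local when the value is used once.
  return consume(max_D.guard_int(__FILE__, __LINE__));
}
```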
@@ -336,7 +339,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     "    int info_B_mask_int64, "
     "    Tensor vbe_B_offsets_rank_per_feature, "
     "    Tensor vbe_output_offsets_feature_rank, "
-    "    int max_B, "
+    "    SymInt max_B, "
     {%- endif %}
     "    bool is_experimental, "
     "    int output_dtype "
@@ -390,7 +393,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     "    Tensor vbe_row_output_offsets, "
     "    Tensor vbe_b_t_map, "
     "    Tensor vbe_B_offsets_rank_per_feature, "
-    "    int max_B, "
+    "    SymInt max_B, "
     {%- endif %}
     "    bool use_uniq_cache_locations, "
     "    bool use_homogeneous_placements,"
@@ -429,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     "    int info_B_num_bits, "
     "    int info_B_mask_int64, "
     "    Tensor vbe_B_offsets_rank_per_feature, "
-    "    int max_B "
+    "    SymInt max_B "
     {%- else %}
     "    Tensor feature_requires_grad"
     {%- endif %}
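These schema-string edits mirror the signature changes above: in the dispatcher's schema language, `int` binds to `int64_t` and `SymInt` binds to `c10::SymInt`, so each registration string must change in lockstep with the C++ wrapper it describes. A minimal sketch of a registration using `SymInt` (hypothetical namespace and op name, not the FBGEMM schema):

```cpp
#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(example_ns, m) {
  // "SymInt max_B" lets callers pass a traced symbolic size; a plain
  // "int" here would force specialization to a concrete value.
  m.def(
      "reshape_demo("
      "    Tensor grad_output, "
      "    SymInt max_B"
      ") -> Tensor");
}
```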