
Commit ba25044

spcyppt authored and facebook-github-bot committed
Symintify max_B and max_D (#3807)
Summary:
Pull Request resolved: #3807
X-link: facebookresearch/FBGEMM#891

Symintify `max_B` and `max_D` to fix graph break https://fburl.com/5hhjzv4h

Reviewed By: q10, sryap, nautsimon

Differential Revision: D71018173

fbshipit-source-id: 4f30e7c73e19a42299236eb27fd192b6c3a51545
1 parent 22c9f3a commit ba25044

3 files changed: +30, -28 lines
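As background for the diffs below, here is a minimal sketch of the symintification pattern this commit applies. The `toy` namespace, the op name, and the function are made up for illustration and are not part of FBGEMM: the registered schema declares `SymInt` instead of `int`, the C++ wrapper takes `c10::SymInt`, and `guard_int` materializes a concrete value only where one is truly needed, so torch.compile can keep the batch size symbolic instead of graph-breaking.

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical op (illustration only). The schema below declares `SymInt max_b`,
// so under torch.compile the value can stay symbolic when the op is traced.
static at::Tensor toy_vbe_reshape(const at::Tensor& grad_output, c10::SymInt max_b) {
  // Specialize to a concrete int64_t only at the point where a plain integer
  // is unavoidable (here, a reshape). This is the same guard_int pattern used
  // in the CPU wrapper changes below.
  const int64_t max_b_int = max_b.guard_int(__FILE__, __LINE__);
  return grad_output.reshape({max_b_int, -1});
}

TORCH_LIBRARY_FRAGMENT(toy, m) {
  // `SymInt` (not `int`) in the schema string is what avoids the graph break.
  m.def("toy_vbe_reshape(Tensor grad_output, SymInt max_b) -> Tensor");
  m.impl("toy_vbe_reshape", &toy_vbe_reshape);
}

The diffs below apply exactly this split: signatures and schemas move from int64_t/int to c10::SymInt/SymInt, and guard_int calls are pushed down into the CPU wrappers that actually need a concrete max_B or max_D.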

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp (+8, -9)

@@ -122,7 +122,7 @@ enum SSDTensor {
 const int64_t /*info_B_mask_int64*/,
 const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and output
 const Tensor& /*vbe_output_offsets_feature_rank*/, // for reshaping vbe cpu output
-const int64_t /*max_B_int*/, // for reshaping vbe cpu offsets
+const c10::SymInt /*max_B*/, // for reshaping vbe cpu offsets
 {%- endif %}
 {%- if is_gwd %}
 const Tensor& /*prev_iter_dev*/,
@@ -169,7 +169,7 @@ enum SSDTensor {
 info_B_mask_int64,
 vbe_B_offsets_rank_per_feature_, // for reshaping vbe cpu offsets and output
 vbe_output_offsets_feature_rank_, // for reshaping vbe cpu output
-max_B_int, // for reshaping vbe cpu offsets
+max_B_, // for reshaping vbe cpu offsets
 {%- endif %} {# /* if vbe */ #}
 {%- if is_gwd %}
 prev_iter_dev_,
@@ -247,7 +247,7 @@ enum SSDTensor {
 const Tensor& /*vbe_row_output_offsets*/,
 const Tensor& /*vbe_b_t_map*/,
 const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and grad output
-const int64_t /*max_B*/, // for reshaping vbe cpu offsets
+const c10::SymInt /*max_B*/, // for reshaping vbe cpu offsets
 {%- endif %}
 const bool /*use_uniq_cache_locations_bwd*/,
 const bool /*use_homogeneous_placements*/,
@@ -695,7 +695,6 @@ class {{ autograd_func }} :
 const auto info_B_mask = static_cast<uint32_t>(aux_int[IDX_INFO_B_MASK]);

 {%- if vbe %}
-const int64_t max_B_int = max_B_.guard_int(__FILE__, __LINE__); // for reshaping vbe cpu offsets and grad_output
 static auto generate_vbe_metadata_op =
 torch::Dispatcher::singleton()
 .findSchemaOrThrow("fbgemm::generate_vbe_metadata", "")
@@ -817,7 +816,7 @@ class {{ autograd_func }} :
 ctx->saved_data["output_dtype"] = output_dtype;
 {%- endif %}
 {%- if vbe %}
-ctx->saved_data["max_B"] = max_B_int; // for reshaping vbe cpu offsets and grad_output
+ctx->saved_data["max_B"] = max_B_; // for reshaping vbe cpu offsets and grad_output
 {%- endif %}

 {%- if not dense %}
@@ -921,7 +920,7 @@ static torch::autograd::variable_list backward(
 {%- endfor %}

 {%- if not nobag %}
-auto max_D = ctx->saved_data["max_D"].toInt();
+auto max_D = ctx->saved_data["max_D"].toSymInt();
 const auto mixed_D = ctx->saved_data["mixed_D"].toBool();
 auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
 {%- else %}
@@ -952,7 +951,7 @@ static torch::autograd::variable_list backward(
 {%- endif %}
 {%- if not dense %}
 {%- if vbe %}
-auto max_B = ctx->saved_data["max_B"].toInt(); // for reshaping vbe cpu offsets and grad_output
+auto max_B = ctx->saved_data["max_B"].toSymInt(); // for reshaping vbe cpu offsets and grad_output
 {%- endif %}

 {%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %}
@@ -1013,7 +1012,7 @@
 const Tensor& /*weights_offsets*/,
 {%- endif %}
 const Tensor& /*D_offsets*/,
-const int64_t /*max_D*/,
+const c10::SymInt /*max_D*/,
 const Tensor& /*indices*/,
 const Tensor& /*offsets*/,
 {%- if ssd %}
@@ -1028,7 +1027,7 @@
 const int64_t /*info_B_num_bits*/,
 const int64_t /*info_B_mask_int64*/,
 const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu grad_output
-const int64_t /*max_B*/ // for reshaping vbe cpu offsets and grad_output
+const c10::SymInt /*max_B*/ // for reshaping vbe cpu offsets and grad_output
 {%- else %}
 const Tensor& /*feature_requires_grad*/
 {%- endif %}
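The saved_data change above relies on the fact that ctx->saved_data stores IValues, and an IValue can hold a c10::SymInt directly. A minimal, hypothetical autograd function (not FBGEMM code; names are illustrative) showing that round trip:

#include <torch/torch.h>

// Illustrative only: a custom autograd Function that saves a c10::SymInt in
// forward() and reads it back with toSymInt() in backward(), mirroring the
// max_B/max_D handling in the template above.
struct ToySymIntFunction : public torch::autograd::Function<ToySymIntFunction> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& input,
      c10::SymInt max_B) {
    ctx->saved_data["max_B"] = max_B;  // stored as an IValue, stays symbolic
    return input * 2;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // toSymInt() (rather than toInt()) avoids forcing a concrete value here;
    // the SymInt can be forwarded to a backward op whose schema takes SymInt.
    auto max_B = ctx->saved_data["max_B"].toSymInt();
    (void)max_B;
    // One gradient per forward input: the tensor gets a grad, the SymInt does not.
    return {grad_outputs[0] * 2, torch::Tensor()};
  }
};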

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp (+16, -13)

@@ -61,24 +61,25 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
 const int64_t info_B_num_bits,
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B
+const c10::SymInt max_B
 {%- else %}
 const Tensor& feature_requires_grad
 {%- endif %}
 ) {
 {%- if vbe %}
 Tensor offsets_;
+const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_grad_indices", [&]() {
 offsets_ = reshape_vbe_offsets<index_t>(
 offsets,
 vbe_B_offsets_rank_per_feature,
-max_B,
+max_B_int,
 D_offsets.numel() - 1
 );
 });
 const auto grad_output_ = reshape_vbe_output(
 grad_output,
-max_B,
+max_B_int,
 vbe_B_offsets_rank_per_feature,
 D_offsets
 );
@@ -128,14 +129,15 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
 const Tensor& vbe_output_offsets_feature_rank,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 const bool /*is_experimental = false*/,
 const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
 Tensor offsets_;
 {%- if vbe %}
+const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_forward", [&]() {
-offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
 });
 {%- endif %}
 static auto op =
@@ -206,7 +208,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 const Tensor& weights_placements,
 const Tensor& weights_offsets,
 const Tensor& D_offsets,
-const int64_t max_D,
+const c10::SymInt max_D,
 const bool mixed_D,
 const Tensor& hash_size_cumsum,
 const int64_t total_hash_size_bits,
@@ -225,7 +227,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 const Tensor& vbe_row_output_offsets,
 const Tensor& vbe_b_t_map,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 const bool /*use_uniq_cache_locations*/,
 const bool /*use_homogeneous_placements*/,
@@ -235,11 +237,12 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 {%- endif %})
 {
 {%- if vbe %}
+const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
 Tensor offsets_;
 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_backward", [&]() {
-offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
 });
-const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
+const auto grad_output_ = reshape_vbe_output(grad_output, max_B_int, vbe_B_offsets_rank_per_feature, D_offsets);
 {%- endif %}
 {%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(
 optimizer
@@ -276,7 +279,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 weights_placements,
 weights_offsets,
 D_offsets,
-max_D,
+max_D.guard_int(__FILE__, __LINE__),
 hash_size_cumsum,
 total_hash_size_bits,
 indices,
@@ -336,7 +339,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
 " Tensor vbe_output_offsets_feature_rank, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 " bool is_experimental, "
 " int output_dtype "
@@ -390,7 +393,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " Tensor vbe_row_output_offsets, "
 " Tensor vbe_b_t_map, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 " bool use_uniq_cache_locations, "
 " bool use_homogeneous_placements,"
@@ -429,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_num_bits, "
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B "
+" SymInt max_B "
 {%- else %}
 " Tensor feature_requires_grad"
 {%- endif %}
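In these registrations, the schema-string type maps directly onto the C++ parameter type: `int` binds to int64_t and `SymInt` binds to c10::SymInt, which is why the wrapper signatures and the TORCH_LIBRARY_FRAGMENT strings change together. A hypothetical call site (again not FBGEMM code, reusing the made-up `toy::toy_vbe_reshape` op from the sketch above) for an op registered with a SymInt argument:

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Illustrative only: looking up an op whose schema declares `SymInt max_b`
// and calling it through a typed handle. The c10::SymInt in the handle's
// signature must match the SymInt in the registered schema string.
static at::Tensor call_toy_vbe_reshape(const at::Tensor& grad_output, c10::SymInt max_b) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("toy::toy_vbe_reshape", "")
      .typed<at::Tensor(const at::Tensor&, c10::SymInt)>();
  return op.call(grad_output, max_b);
}

This is the same dispatcher lookup pattern visible in the generate_vbe_metadata hunk of the autograd template above.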

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp (+6, -6)

@@ -95,7 +95,7 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
 const Tensor& vbe_output_offsets_feature_rank,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 {%- if is_gwd %}
 const Tensor& prev_iter_dev,
@@ -245,7 +245,7 @@ Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{
 const Tensor& vbe_row_output_offsets,
 const Tensor& vbe_b_t_map,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 const bool use_uniq_cache_locations,
 const bool use_homogeneous_placements,
@@ -410,7 +410,7 @@ Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_{{ d
 const int64_t info_B_num_bits,
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B
+const c10::SymInt max_B
 {%- else %}
 const Tensor& feature_requires_grad
 {%- endif %}
@@ -538,7 +538,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
 " Tensor vbe_output_offsets_feature_rank, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 {%- if is_gwd %}
 " Tensor{{ schema_annotation['prev_iter_dev'] }} prev_iter_dev, "
@@ -610,7 +610,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " Tensor vbe_row_output_offsets, "
 " Tensor vbe_b_t_map, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 " bool use_uniq_cache_locations, "
 " bool use_homogeneous_placements,"
@@ -670,7 +670,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_num_bits, "
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B "
+" SymInt max_B "
 {%- else %}
 " Tensor feature_requires_grad"
 {%- endif %}
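The CUDA wrapper hunks touch only signatures and schemas; no guard_int is added here. Keeping the value as c10::SymInt is cheap because SymInt supports ordinary arithmetic while staying symbolic, so any guard can be deferred to the single place a concrete value is required. A small illustrative helper (not from this commit; names are made up):

#include <c10/core/SymInt.h>

// Illustrative only: arithmetic on c10::SymInt stays symbolic under tracing,
// so derived quantities can be computed from a symbolic max_B without a guard;
// guard_int is called only where a concrete value is unavoidable.
static int64_t flattened_rows(c10::SymInt max_B, int64_t num_features) {
  c10::SymInt total = max_B * num_features;   // still symbolic if max_B is
  return total.guard_int(__FILE__, __LINE__); // specialize only here
}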
