Symintify max_B and max_D #3807

Closed · wants to merge 1 commit
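The change is mechanical but follows one pattern throughout: every `int64_t max_B` / `int64_t max_D` that only describes a (possibly dynamic) batch or embedding dimension becomes a `c10::SymInt`, and conversion to a concrete integer is pushed down to the point where a backend actually needs one. A minimal sketch of that pattern, with a hypothetical helper name that is not part of the PR:

```cpp
#include <c10/core/SymInt.h>
#include <cstdint>

// Hypothetical helper, not FBGEMM code: accept the batch bound symbolically
// so dynamic shapes can survive tracing, and specialize only when required.
int64_t buffer_rows_needed(const c10::SymInt& max_B) {
  // guard_int() materializes a concrete value and records the source
  // location responsible for the specialization.
  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
  return max_B_int + 1;  // e.g. one extra row for an offsets buffer
}
```

The point is that symbolic sizes flow through the signatures untouched, so guards only appear where concrete values are truly required.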
(first file)
@@ -122,7 +122,7 @@ enum SSDTensor {
 const int64_t /*info_B_mask_int64*/,
 const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and output
 const Tensor& /*vbe_output_offsets_feature_rank*/, // for reshaping vbe cpu output
-const int64_t /*max_B_int*/, // for reshaping vbe cpu offsets
+const c10::SymInt /*max_B*/, // for reshaping vbe cpu offsets
 {%- endif %}
 {%- if is_gwd %}
 const Tensor& /*prev_iter_dev*/,
@@ -169,7 +169,7 @@ enum SSDTensor {
 info_B_mask_int64,
 vbe_B_offsets_rank_per_feature_, // for reshaping vbe cpu offsets and output
 vbe_output_offsets_feature_rank_, // for reshaping vbe cpu output
-max_B_int, // for reshaping vbe cpu offsets
+max_B_, // for reshaping vbe cpu offsets
 {%- endif %} {# /* if vbe */ #}
 {%- if is_gwd %}
 prev_iter_dev_,
@@ -247,7 +247,7 @@ enum SSDTensor {
 const Tensor& /*vbe_row_output_offsets*/,
 const Tensor& /*vbe_b_t_map*/,
 const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and grad output
-const int64_t /*max_B*/, // for reshaping vbe cpu offsets
+const c10::SymInt /*max_B*/, // for reshaping vbe cpu offsets
 {%- endif %}
 const bool /*use_uniq_cache_locations_bwd*/,
 const bool /*use_homogeneous_placements*/,
@@ -695,7 +695,6 @@ class {{ autograd_func }} :
 const auto info_B_mask = static_cast<uint32_t>(aux_int[IDX_INFO_B_MASK]);

 {%- if vbe %}
-const int64_t max_B_int = max_B_.guard_int(__FILE__, __LINE__); // for reshaping vbe cpu offsets and grad_output
 static auto generate_vbe_metadata_op =
 torch::Dispatcher::singleton()
 .findSchemaOrThrow("fbgemm::generate_vbe_metadata", "")
@@ -817,7 +816,7 @@ class {{ autograd_func }} :
 ctx->saved_data["output_dtype"] = output_dtype;
 {%- endif %}
 {%- if vbe %}
-ctx->saved_data["max_B"] = max_B_int; // for reshaping vbe cpu offsets and grad_output
+ctx->saved_data["max_B"] = max_B_; // for reshaping vbe cpu offsets and grad_output
 {%- endif %}

 {%- if not dense %}
@@ -921,7 +920,7 @@ static torch::autograd::variable_list backward(
 {%- endfor %}

 {%- if not nobag %}
-auto max_D = ctx->saved_data["max_D"].toInt();
+auto max_D = ctx->saved_data["max_D"].toSymInt();
 const auto mixed_D = ctx->saved_data["mixed_D"].toBool();
 auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
 {%- else %}
@@ -952,7 +951,7 @@ static torch::autograd::variable_list backward(
 {%- endif %}
 {%- if not dense %}
 {%- if vbe %}
-auto max_B = ctx->saved_data["max_B"].toInt(); // for reshaping vbe cpu offsets and grad_output
+auto max_B = ctx->saved_data["max_B"].toSymInt(); // for reshaping vbe cpu offsets and grad_output
 {%- endif %}

 {%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %}
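The forward/backward plumbing above leans on the fact that autograd `saved_data` holds `IValue`s, which can store a `SymInt` directly; saving `max_B_` as-is and reading it back with `toSymInt()` keeps the value symbolic across the autograd boundary. A stripped-down sketch of that round trip using a toy `Function`, not the generated FBGEMM autograd class:

```cpp
#include <torch/torch.h>

// Toy autograd Function illustrating the saved_data round trip for a SymInt.
struct SymIntRoundTrip : public torch::autograd::Function<SymIntRoundTrip> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x,
      c10::SymInt max_B) {
    // IValue accepts a SymInt directly, so no guard_int is needed here.
    ctx->saved_data["max_B"] = max_B;
    return x.clone();
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // toSymInt() (rather than toInt()) keeps backward symbolic as well.
    auto max_B = ctx->saved_data["max_B"].toSymInt();
    (void)max_B;  // the real code forwards this to the backward op
    // One gradient slot per forward input; max_B gets an undefined tensor.
    return {grad_outputs[0], torch::Tensor()};
  }
};
```

The call site would be `SymIntRoundTrip::apply(x, max_B)`; the templated `{{ autograd_func }}` class does the equivalent at generation time.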
@@ -1013,7 +1012,7 @@ static torch::autograd::variable_list backward(
 const Tensor& /*weights_offsets*/,
 {%- endif %}
 const Tensor& /*D_offsets*/,
-const int64_t /*max_D*/,
+const c10::SymInt /*max_D*/,
 const Tensor& /*indices*/,
 const Tensor& /*offsets*/,
 {%- if ssd %}
@@ -1028,7 +1027,7 @@
 const int64_t /*info_B_num_bits*/,
 const int64_t /*info_B_mask_int64*/,
 const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu grad_output
-const int64_t /*max_B*/ // for reshaping vbe cpu offsets and grad_output
+const c10::SymInt /*max_B*/ // for reshaping vbe cpu offsets and grad_output
 {%- else %}
 const Tensor& /*feature_requires_grad*/
 {%- endif %}
(second file)
@@ -61,24 +61,25 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
 const int64_t info_B_num_bits,
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B
+const c10::SymInt max_B
 {%- else %}
 const Tensor& feature_requires_grad
 {%- endif %}
 ) {
 {%- if vbe %}
 Tensor offsets_;
+const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_grad_indices", [&]() {
 offsets_ = reshape_vbe_offsets<index_t>(
 offsets,
 vbe_B_offsets_rank_per_feature,
-max_B,
+max_B_int,
 D_offsets.numel() - 1
 );
 });
 const auto grad_output_ = reshape_vbe_output(
 grad_output,
-max_B,
+max_B_int,
 vbe_B_offsets_rank_per_feature,
 D_offsets
 );
@@ -128,14 +129,15 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
 const Tensor& vbe_output_offsets_feature_rank,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 const bool /*is_experimental = false*/,
 const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
 Tensor offsets_;
 {%- if vbe %}
+const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_forward", [&]() {
-offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
 });
 {%- endif %}
 static auto op =
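On the CPU wrappers the pattern is the mirror image: the wrapper signature (and its registered schema) now takes `SymInt`, but `reshape_vbe_offsets` / `reshape_vbe_output` still need real sizes, so the value is guarded once at the top of the wrapper and the concrete `max_B_int` is used from there on. Roughly, with placeholder names rather than the actual wrapper:

```cpp
#include <algorithm>
#include <torch/torch.h>

// Placeholder for a CPU wrapper (not FBGEMM's reshape helpers): guard the
// symbolic bound exactly once, then use the concrete value everywhere below.
torch::Tensor cpu_wrapper_sketch(const torch::Tensor& offsets, c10::SymInt max_B) {
  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
  // Pad (or trim) the 1-D offsets to exactly max_B_int + 1 entries, the kind
  // of concrete-size bookkeeping the CPU reshape helpers perform.
  const int64_t keep = std::min<int64_t>(offsets.numel(), max_B_int + 1);
  auto out = torch::zeros({max_B_int + 1}, offsets.options());
  out.narrow(0, 0, keep).copy_(offsets.narrow(0, 0, keep));
  return out;
}
```

Guarding here rather than in the autograd class (the eager `guard_int` dropped in the first file) means paths that never call these CPU reshape helpers never force the specialization.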
@@ -206,7 +208,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 const Tensor& weights_placements,
 const Tensor& weights_offsets,
 const Tensor& D_offsets,
-const int64_t max_D,
+const c10::SymInt max_D,
 const bool mixed_D,
 const Tensor& hash_size_cumsum,
 const int64_t total_hash_size_bits,
@@ -225,7 +227,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 const Tensor& vbe_row_output_offsets,
 const Tensor& vbe_b_t_map,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 const bool /*use_uniq_cache_locations*/,
 const bool /*use_homogeneous_placements*/,
@@ -235,11 +237,12 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 {%- endif %})
 {
 {%- if vbe %}
+const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
 Tensor offsets_;
 AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "reshape_vbe_offsets_cpu_backward", [&]() {
-offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
+offsets_ = reshape_vbe_offsets<index_t>(offsets, vbe_B_offsets_rank_per_feature, max_B_int, D_offsets.numel() - 1);
 });
-const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
+const auto grad_output_ = reshape_vbe_output(grad_output, max_B_int, vbe_B_offsets_rank_per_feature, D_offsets);
 {%- endif %}
 {%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(
 optimizer
@@ -276,7 +279,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
 weights_placements,
 weights_offsets,
 D_offsets,
-max_D,
+max_D.guard_int(__FILE__, __LINE__),
 hash_size_cumsum,
 total_hash_size_bits,
 indices,
@@ -336,7 +339,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
 " Tensor vbe_output_offsets_feature_rank, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 " bool is_experimental, "
 " int output_dtype "
@@ -390,7 +393,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " Tensor vbe_row_output_offsets, "
 " Tensor vbe_b_t_map, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 " bool use_uniq_cache_locations, "
 " bool use_homogeneous_placements,"
@@ -429,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_num_bits, "
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B "
+" SymInt max_B "
 {%- else %}
 " Tensor feature_requires_grad"
 {%- endif %}
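The schema strings change in lockstep: `int` becomes `SymInt` so the dispatcher-level signature matches the `c10::SymInt` parameters on the C++ side. A self-contained sketch of how a `SymInt` schema argument lines up with a C++ kernel, using a hypothetical namespace and op name that are not part of this PR:

```cpp
#include <algorithm>
#include <torch/library.h>
#include <torch/torch.h>

// The C++ kernel takes c10::SymInt because the schema below declares SymInt.
torch::Tensor pad_rows_to_max_b(const torch::Tensor& x, c10::SymInt max_B) {
  const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__);
  const int64_t pad = std::max<int64_t>(0, max_B_int - x.size(0));
  return torch::constant_pad_nd(x, {0, 0, 0, pad});  // pad rows of a 2-D tensor
}

TORCH_LIBRARY(symint_sketch, m) {
  // "SymInt max_B" (instead of "int max_B") is what lets symbolic sizes flow
  // through the dispatcher without an immediate guard.
  m.def("pad_rows_to_max_b(Tensor x, SymInt max_B) -> Tensor");
}

TORCH_LIBRARY_IMPL(symint_sketch, CPU, m) {
  m.impl("pad_rows_to_max_b", TORCH_FN(pad_rows_to_max_b));
}
```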
(third file)
@@ -95,7 +95,7 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
 const Tensor& vbe_output_offsets_feature_rank,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 {%- if is_gwd %}
 const Tensor& prev_iter_dev,
@@ -245,7 +245,7 @@ Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{
 const Tensor& vbe_row_output_offsets,
 const Tensor& vbe_b_t_map,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B,
+const c10::SymInt max_B,
 {%- endif %}
 const bool use_uniq_cache_locations,
 const bool use_homogeneous_placements,
@@ -410,7 +410,7 @@ Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_{{ d
 const int64_t info_B_num_bits,
 const int64_t info_B_mask_int64,
 const Tensor& vbe_B_offsets_rank_per_feature,
-const int64_t max_B
+const c10::SymInt max_B
 {%- else %}
 const Tensor& feature_requires_grad
 {%- endif %}
@@ -538,7 +538,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
 " Tensor vbe_output_offsets_feature_rank, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 {%- if is_gwd %}
 " Tensor{{ schema_annotation['prev_iter_dev'] }} prev_iter_dev, "
@@ -610,7 +610,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " Tensor vbe_row_output_offsets, "
 " Tensor vbe_b_t_map, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B, "
+" SymInt max_B, "
 {%- endif %}
 " bool use_uniq_cache_locations, "
 " bool use_homogeneous_placements,"
@@ -670,7 +670,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 " int info_B_num_bits, "
 " int info_B_mask_int64, "
 " Tensor vbe_B_offsets_rank_per_feature, "
-" int max_B "
+" SymInt max_B "
 {%- else %}
 " Tensor feature_requires_grad"
 {%- endif %}