Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions oneflow/user/ops/dropout_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,29 @@ limitations under the License.

namespace oneflow {

// ONEFLOW_DROPOUT_MASK_USE_BITS is set to true when targeting NPU: the dropout
// mask is then stored as packed bits (uint8) instead of one bool per element.
DEFINE_ENV_BOOL(ONEFLOW_DROPOUT_MASK_USE_BITS, false);

/// Infers the logical shapes of "out" and "mask".
/// "out" always matches the input shape. "mask" matches the input shape in the
/// default (byte-mask) mode; in bit-mask mode (NPU) it is a flat 1-D buffer of
/// bytes sized so each device's bit count is aligned up to 128.
/* static */ Maybe<void> DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& in_shape = ctx->InputShape("in", 0);
  ctx->SetOutputShape("out", 0, in_shape);
  ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));

  Shape mask_shape = in_shape;
  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
    // JUST FOR NPU: compute the bit-packed mask size with 128-bit alignment.
    const ParallelDesc& parallel_desc = ctx->parallel_desc();
    const int64_t parallel_num = parallel_desc.parallel_num();
    const int64_t elem_cnt = in_shape.elem_cnt();

    // Bytes per device: ceil(elem_cnt / parallel_num) rounded up to a multiple
    // of 128 bits, then divided by 8 to convert bits to bytes.
    const int64_t per_device_elem =
        ((elem_cnt + parallel_num - 1) / parallel_num + 127) / 128 * 128 / 8;
    // Logical (global) mask size is the per-device byte count times device count.
    const int64_t global_aligned_size = parallel_num * per_device_elem;
    mask_shape = Shape({global_aligned_size});
  }

  ctx->SetOutputShape("mask", 0, mask_shape);
  ctx->SetOutputIsDynamic("mask", 0, ctx->InputIsDynamic("in", 0));
  return Maybe<void>::Ok();
}

Expand All @@ -32,9 +50,22 @@ namespace oneflow {

/// Declares the valid SBP (split) signatures for dropout.
/// Default mode: in/out/mask all split on the same axis.
/// Bit-mask mode (NPU): the mask is a flat aligned 1-D buffer whose layout is
/// per-device, so it is always split on axis 0 regardless of the data axis.
/* static */ Maybe<void> DropoutOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);

  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
    FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
      ctx->NewBuilder()
          .Split(user_op::OpArg("in", 0), axis)
          .Split(user_op::OpArg("out", 0), axis)
          .Split(user_op::OpArg("mask", 0), 0)
          .Build();
    }

  } else {
    FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
      ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
    }
  }

  return Maybe<void>::Ok();
}

Expand All @@ -47,15 +78,23 @@ namespace oneflow {

/// Infers output dtypes: "out" matches the input dtype; "mask" is uint8 in
/// bit-packed (NPU) mode and bool otherwise.
/* static */ Maybe<void> DropoutOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
    ctx->SetOutputDType("mask", 0, DataType::kUInt8);
  } else {
    ctx->SetOutputDType("mask", 0, DataType::kBool);
  }
  return Maybe<void>::Ok();
}

/// Infers the shape of "dx" (same as "dy") and validates the mask shape.
/// The shape check is skipped in bit-mask (NPU) mode because there the mask is
/// a flat 128-bit-aligned buffer, not shaped like dy.
/* static */ Maybe<void> DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  const Shape& dy_shape = ctx->InputShape("dy", 0);
  ctx->SetOutputShape("dx", 0, dy_shape);
  ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("dy", 0));
  // mask shape is same as dy_shape when using bytes
  // mask shape is align(dy_shape.elem_cnt, 128) when using bits (NPU)
  if (!EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
    CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape);
  }
  return Maybe<void>::Ok();
}

Expand All @@ -65,13 +104,25 @@ namespace oneflow {

/// Declares SBP signatures for the dropout backward op, mirroring DropoutOp:
/// dy/dx split on the same axis; mask split on the same axis in byte-mask mode
/// but always on axis 0 in bit-mask (NPU) mode, where it is a flat 1-D buffer.
/* static */ Maybe<void> DropoutGradOp::GetSbp(user_op::SbpContext* ctx) {
  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);

  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
    FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
      ctx->NewBuilder()
          .Split(user_op::OpArg("dy", 0), axis)
          .Split(user_op::OpArg("dx", 0), axis)
          .Split(user_op::OpArg("mask", 0), 0)
          .Build();
    }
  } else {
    FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
      ctx->NewBuilder()
          .Split(user_op::OpArg("dy", 0), axis)
          .Split(user_op::OpArg("mask", 0), axis)
          .Split(user_op::OpArg("dx", 0), axis)
          .Build();
    }
  }

  return Maybe<void>::Ok();
}

Expand All @@ -84,9 +135,16 @@ namespace oneflow {

/// Infers "dx" dtype (same as "dy") and validates the mask dtype:
/// uint8 in bit-packed (NPU) mode, bool otherwise — matching DropoutOp.
/* static */ Maybe<void> DropoutGradOp::InferDataType(user_op::InferContext* ctx) {
  ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0));

  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
    CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kUInt8)
        << "InferDataType Failed. Expected mask to be UINT8 in NPU, but got "
        << DataType_Name(ctx->InputDType("mask", 0));
  } else {
    CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool)
        << "InferDataType Failed. Expected " << DataType_Name(DataType::kBool) << ", but got "
        << DataType_Name(ctx->InputDType("mask", 0));
  }
  return Maybe<void>::Ok();
}

Expand Down