diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp
index 10f8b384ed7..39d610ad747 100644
--- a/oneflow/user/ops/dropout_op.cpp
+++ b/oneflow/user/ops/dropout_op.cpp
@@ -18,11 +18,29 @@ limitations under the License.
 
 namespace oneflow {
 
+// ONEFLOW_DROPOUT_MASK_USE_BITS is true when using NPU
+DEFINE_ENV_BOOL(ONEFLOW_DROPOUT_MASK_USE_BITS, false);
+
 /* static */ Maybe<void> DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const Shape& in_shape = ctx->InputShape("in", 0);
   ctx->SetOutputShape("out", 0, in_shape);
-  ctx->SetOutputShape("mask", 0, in_shape);
   ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));
+
+  Shape mask_shape = in_shape;
+  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
+    // JUST FOR NPU: Compute mask shape considering alignment(128)
+    const ParallelDesc& parallel_desc = ctx->parallel_desc();
+    const int64_t parallel_num = parallel_desc.parallel_num();
+    const int64_t elem_cnt = in_shape.elem_cnt();
+
+    const int64_t per_device_elem =
+        ((elem_cnt + parallel_num - 1) / parallel_num + 127) / 128 * 128 / 8;
+    const int64_t global_aligned_size = parallel_num * per_device_elem;
+    mask_shape = Shape({global_aligned_size});
+  }
+
+  ctx->SetOutputShape("mask", 0, mask_shape);
+  ctx->SetOutputIsDynamic("mask", 0, ctx->InputIsDynamic("in", 0));
   return Maybe<void>::Ok();
 }
 
@@ -32,9 +50,22 @@ namespace oneflow {
 
 /* static */ Maybe<void> DropoutOp::GetSbp(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
-    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
+
+  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
+    FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("in", 0), axis)
+          .Split(user_op::OpArg("out", 0), axis)
+          .Split(user_op::OpArg("mask", 0), 0)
+          .Build();
+    }
+
+  } else {
+    FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
+      ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
+    }
   }
+
   return Maybe<void>::Ok();
 }
 
@@ -47,7 +78,11 @@ namespace oneflow {
 
 /* static */ Maybe<void> DropoutOp::InferDataType(user_op::InferContext* ctx) {
   ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
-  ctx->SetOutputDType("mask", 0, DataType::kBool);
+  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
+    ctx->SetOutputDType("mask", 0, DataType::kUInt8);
+  } else {
+    ctx->SetOutputDType("mask", 0, DataType::kBool);
+  }
   return Maybe<void>::Ok();
 }
 
@@ -55,7 +90,11 @@ namespace oneflow {
   const Shape& dy_shape = ctx->InputShape("dy", 0);
   ctx->SetOutputShape("dx", 0, dy_shape);
   ctx->SetOutputIsDynamic("dx", 0, ctx->InputIsDynamic("dy", 0));
-  CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape);
+  // mask shape is same as dy_shape when using bytes
+  // mask shape is align(dy_shape.elem_cnt, 128) when using bits (NPU)
+  if (!EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
+    CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape);
+  }
   return Maybe<void>::Ok();
 }
 
@@ -65,13 +104,25 @@ namespace oneflow {
 
 /* static */ Maybe<void> DropoutGradOp::GetSbp(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
-  FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
-    ctx->NewBuilder()
-        .Split(user_op::OpArg("dy", 0), axis)
-        .Split(user_op::OpArg("mask", 0), axis)
-        .Split(user_op::OpArg("dx", 0), axis)
-        .Build();
+
+  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
+    FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("dy", 0), axis)
+          .Split(user_op::OpArg("dx", 0), axis)
+          .Split(user_op::OpArg("mask", 0), 0)
+          .Build();
+    }
+  } else {
+    FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("dy", 0), axis)
+          .Split(user_op::OpArg("mask", 0), axis)
+          .Split(user_op::OpArg("dx", 0), axis)
+          .Build();
+    }
   }
+
   return Maybe<void>::Ok();
 }
 
@@ -84,9 +135,16 @@ namespace oneflow {
 
 /* static */ Maybe<void> DropoutGradOp::InferDataType(user_op::InferContext* ctx) {
   ctx->SetOutputDType("dx", 0, ctx->InputDType("dy", 0));
-  CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool)
-      << "InferDataType Failed. Expected " << DataType_Name(DataType::kBool) << ", but got "
-      << DataType_Name(ctx->InputDType("mask", 0));
+
+  if (EnvBool<ONEFLOW_DROPOUT_MASK_USE_BITS>()) {
+    CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kUInt8)
+        << "InferDataType Failed. Expected mask to be UINT8 in NPU, but got "
+        << DataType_Name(ctx->InputDType("mask", 0));
+  } else {
+    CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kBool)
+        << "InferDataType Failed. Expected " << DataType_Name(DataType::kBool) << ", but got "
+        << DataType_Name(ctx->InputDType("mask", 0));
+  }
   return Maybe<void>::Ok();
 }