Skip to content

Commit cbc4d1d

Browse files
authored
fix recover porgram on dropout op (#1859)
1 parent 574c816 commit cbc4d1d

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

paddleslim/common/recover_program.py

+11
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ def _recover_outputs_attr(program):
6161
persistable=False,
6262
stop_gradient=True)
6363
op.desc.set_output("XShape", [xshape.name])
64+
if op.type == 'dropout':
65+
if 'Mask' not in op.output_names:
66+
mask = block.create_var(
67+
name=paddle.utils.unique_name.
68+
generate_with_ignorable_key(".".join(["mask", 'tmp'])),
69+
dtype=block.var(op.input("X")[0]).dtype,
70+
type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
71+
shape=block.var(op.input("X")[0]).shape,
72+
persistable=False,
73+
stop_gradient=True)
74+
op.desc.set_output("Mask", [mask.name])
6475
return program
6576

6677

0 commit comments

Comments
 (0)