We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 574c816 commit cbc4d1dCopy full SHA for cbc4d1d
paddleslim/common/recover_program.py
@@ -61,6 +61,17 @@ def _recover_outputs_attr(program):
61
persistable=False,
62
stop_gradient=True)
63
op.desc.set_output("XShape", [xshape.name])
64
+ if op.type == 'dropout':
65
+ if 'Mask' not in op.output_names:
66
+ mask = block.create_var(
67
+ name=paddle.utils.unique_name.
68
+ generate_with_ignorable_key(".".join(["mask", 'tmp'])),
69
+ dtype=block.var(op.input("X")[0]).dtype,
70
+ type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
71
+ shape=block.var(op.input("X")[0]).shape,
72
+ persistable=False,
73
+ stop_gradient=True)
74
+ op.desc.set_output("Mask", [mask.name])
75
return program
76
77
0 commit comments