提交 bf17e33b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed gradient bug with inplace operations

上级 69f2323f
...@@ -81,6 +81,22 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -81,6 +81,22 @@ def grad_sources_inputs(sources, graph_inputs):
output_arg = g_outputs output_arg = g_outputs
input_arg = op.inputs input_arg = op.inputs
try:
dinputs = [x[0] for x in op.destroy_map().values()]
except AttributeError:
dinputs = []
# input_arg = [input in dinputs and input.copy() or input for input in input_arg]
new_input_arg = []
for input in input_arg:
if input in dinputs:
new_input_arg.append(input.copy())
else:
new_input_arg.append(input)
input_arg = new_input_arg
op_grad = op.grad(input_arg, output_arg) op_grad = op.grad(input_arg, output_arg)
if not isinstance(op_grad, (list,tuple)): if not isinstance(op_grad, (list,tuple)):
raise ValueError(_msg_retType, op.__class__) raise ValueError(_msg_retType, op.__class__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论