Commit 43e64647, authored by Ian Goodfellow

copy only those inputs marked for destruction

Parent 40381bca
......@@ -544,18 +544,22 @@ def _populate_grad_dict(var_to_node_to_idx,\
 inputs = node.inputs
-def try_to_copy(var):
-    if hasattr(var,'copy'):
+# Each Op's grad function requires inputs and output_grads
+# If the Op destroys any input, but the grad expression uses it,
+# then chances are the resulting graph will have a dependency
+# cycle. We avoid this cycle by passing (symbolic) copies of
+# each destroyed input.
+try:
+    dinputs = [node.inputs[x[0]] for x in node.op.destroy_map.values()]
+except AttributeError:
+    dinputs = []
+def try_to_copy_if_needed(var):
+    if var in dinputs and hasattr(var,'copy'):
         return var.copy()
     return var
-#inplace ops often have inplace in their expression for the gradient
-#this can result in cyclical dependencies, ie there not being an order
-#in which we can run all the resulting inplace ops without destroying
-#some op's input before the time that it is needed
-#to get around this, we try to symbolically copy all of the inputs
-#so it is only the copy that is destroyed
-inputs = [try_to_copy(ipt) for ipt in inputs ]
+inputs = [try_to_copy_if_needed(ipt) for ipt in inputs ]
 output_grads = [ access_grad_cache(var) for var in node.outputs ]
......
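The selective copy in this hunk can be illustrated outside the library with a small stand-alone sketch. Everything here is hypothetical scaffolding rather than the project's API: `Var`, `Node`, `InplaceOp`, `PureOp`, `destroyed_inputs`, and `inputs_for_grad` are invented names. The sketch assumes `destroy_map` maps an output index to a list of input indices and, like the commit, looks only at the first index in each list.

```python
# Toy stand-ins for graph objects. These classes are hypothetical
# and only mimic the attributes the diff relies on.
class Var:
    def __init__(self, name):
        self.name = name

    def copy(self):
        # A "symbolic copy": a new variable distinct from the original.
        return Var(self.name + "_copy")


class InplaceOp:
    # Assumed convention: {output_index: [input_indices]}.
    # Here output 0 overwrites input 1.
    destroy_map = {0: [1]}


class PureOp:
    # No destroy_map attribute, so this op destroys nothing.
    pass


class Node:
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = inputs


def destroyed_inputs(node):
    """Return the inputs that node.op marks as destroyed."""
    destroy_map = getattr(node.op, "destroy_map", {})
    # As in the commit, only the first index of each entry is used.
    return [node.inputs[idxs[0]] for idxs in destroy_map.values()]


def inputs_for_grad(node):
    """Copy only the destroyed inputs; pass the rest through unchanged."""
    dinputs = destroyed_inputs(node)

    def copy_if_needed(var):
        # Identity comparison avoids depending on how == is defined.
        if any(var is d for d in dinputs) and hasattr(var, "copy"):
            return var.copy()
        return var

    return [copy_if_needed(v) for v in node.inputs]


x, y = Var("x"), Var("y")
inplace = inputs_for_grad(Node(InplaceOp(), [x, y]))
pure = inputs_for_grad(Node(PureOp(), [x, y]))
print([v.name for v in inplace])  # ['x', 'y_copy']
print([v.name for v in pure])     # ['x', 'y']
```

Only `y` is copied for the in-place op, and nothing is copied for the op without a `destroy_map`. The sketch uses `getattr` with a default where the commit uses `try`/`except AttributeError`; the two are equivalent when the only possible failure is a missing attribute.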