提交 ccdcc319 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

moved some preprocessing code inside a branch so it only executes if

needed
上级 79c950fa
...@@ -672,30 +672,30 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -672,30 +672,30 @@ def _populate_grad_dict(var_to_node_to_idx,
inputs = node.inputs inputs = node.inputs
# Each Op's grad function requires inputs and output_grads
# If the Op destroys any input, but the grad expression uses it,
# then chances are the resulting graph will have a dependency
# cycle. We avoid this cycle by passing (symbolic) copies of
# each destroyed input.
try:
dinputs = [node.inputs[x[0]] for x in
node.op.destroy_map.values()]
except AttributeError:
dinputs = []
def try_to_copy_if_needed(var):
if var in dinputs and hasattr(var, 'copy'):
return var.copy()
return var
inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]
output_grads = [access_grad_cache(var) for var in node.outputs] output_grads = [access_grad_cache(var) for var in node.outputs]
if False in [isinstance(g.type, DisconnectedType) if False in [isinstance(g.type, DisconnectedType)
for g in output_grads]: for g in output_grads]:
# Some outputs of this op are connected to the cost so we must # Some outputs of this op are connected to the cost so we must
# call the ops grad method # call the op's grad method
# Each Op's grad function requires inputs and output_grads
# If the Op destroys any input, but the grad expression uses it,
# then chances are the resulting graph will have a dependency
# cycle. We avoid this cycle by passing (symbolic) copies of
# each destroyed input.
try:
dinputs = [node.inputs[x[0]] for x in
node.op.destroy_map.values()]
except AttributeError:
dinputs = []
def try_to_copy_if_needed(var):
if var in dinputs and hasattr(var, 'copy'):
return var.copy()
return var
inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]
input_grads = node.op.grad(inputs, output_grads) input_grads = node.op.grad(inputs, output_grads)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论