Commit f90b7022 authored by Ian Goodfellow

made gradient.grad smarter about not calling Op.grad

(this shouldn't change the output in any case, just eliminates unnecessary calls to Op.grad)
Parent 74bac5aa
@@ -676,9 +676,24 @@ def _populate_grad_dict(var_to_node_to_idx,
         output_grads = [access_grad_cache(var) for var in node.outputs]
-        if False in [isinstance(g.type, DisconnectedType)
-                for g in output_grads]:
-            # Some outputs of this op are connected to the cost so we must
+
+        # list of bools indicating if each output is connected to the cost
+        outputs_connected = [not isinstance(g.type, DisconnectedType)
+                             for g in output_grads]
+
+        connection_pattern = _node_to_pattern(node)
+
+        # list of bools indicating if each input is connected to the cost
+        inputs_connected = [
+            (True in [input_to_output and output_to_cost for
+                      input_to_output, output_to_cost in
+                      zip(input_to_outputs, outputs_connected)])
+            for input_to_outputs in connection_pattern]
+
+        if True in inputs_connected:
+            # At least one input of this op is connected to the cost so we must
             # call the op's grad method
             # Each Op's grad function requires inputs and output_grads
...
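The logic in the diff above can be sketched in isolation. The following is a minimal, self-contained illustration (not Theano's actual API — `op_grad_needed` and the boolean flags standing in for `DisconnectedType` gradients are hypothetical names for this sketch): an input is connected to the cost iff the connection pattern links it to at least one output whose gradient is itself connected, and `Op.grad` only needs to be called when some input is connected.

```python
# Hypothetical helper illustrating the commit's logic; connection_pattern[i][j]
# is True iff input i can affect output j, and output_grads_disconnected[j]
# is True iff output j received a disconnected gradient.
def op_grad_needed(connection_pattern, output_grads_disconnected):
    """Return True iff at least one input is connected to the cost."""
    # An output is connected to the cost iff its incoming gradient is
    # not disconnected (modelled here as a plain bool flag).
    outputs_connected = [not d for d in output_grads_disconnected]
    # An input is connected iff it feeds some connected output.
    inputs_connected = [
        any(i_to_o and o_to_cost
            for i_to_o, o_to_cost in zip(input_row, outputs_connected))
        for input_row in connection_pattern
    ]
    return any(inputs_connected)

# Op with 2 inputs and 2 outputs: input 0 affects only output 0,
# input 1 affects only output 1.
pattern = [[True, False], [False, True]]

# Only output 1's gradient is disconnected -> input 0 is still connected.
print(op_grad_needed(pattern, [False, True]))   # True: must call Op.grad
# Both output gradients disconnected -> the Op.grad call can be skipped.
print(op_grad_needed(pattern, [True, True]))    # False: unnecessary call eliminated
```

As the commit message notes, skipping the call when every input is disconnected cannot change the result; it only avoids work.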