提交 d85e6795 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bugs introduced by recent commits

上级 896fdba3
......@@ -668,8 +668,9 @@ def _populate_grad_dict(var_to_node_to_idx,
# its inputs' gradients
term_dict = {}
# populate term_dict[node] and return it
def access_term_cache(node):
""" Populates term_dict[node] and returns it """
if node not in term_dict:
inputs = node.inputs
......@@ -684,15 +685,13 @@ def _populate_grad_dict(var_to_node_to_idx,
# list of bools indicating if each input is connected to the cost
inputs_connected = [
[ input_to_output and output_to_cost for
(True in [ input_to_output and output_to_cost for
input_to_output, output_to_cost in
zip(input_to_outputs, outputs_connected) ] for
zip(input_to_outputs, outputs_connected) ]) for
input_to_outputs in connection_pattern
]
all_inputs_disconnected = not ( True in inputs_connected )
if all_inputs_disconnected:
if True in inputs_connected:
# At least one input of this op is connected to the cost so we must
# call the op's grad method
......@@ -763,7 +762,7 @@ def _populate_grad_dict(var_to_node_to_idx,
node.op, g_r_type, i, r_type)
#cache the result
term_dict[node] = list(input_grads)
term_dict[node] = input_grads
return term_dict[node]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论