提交 d85e6795 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bugs introduced by recent commits

上级 896fdba3
...@@ -668,8 +668,9 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -668,8 +668,9 @@ def _populate_grad_dict(var_to_node_to_idx,
# its inputs' gradients # its inputs' gradients
term_dict = {} term_dict = {}
# populate term_dict[node] and return it
def access_term_cache(node): def access_term_cache(node):
""" Populates term_dict[node] and returns it """
if node not in term_dict: if node not in term_dict:
inputs = node.inputs inputs = node.inputs
...@@ -684,15 +685,13 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -684,15 +685,13 @@ def _populate_grad_dict(var_to_node_to_idx,
# list of bools indicating if each input is connected to the cost # list of bools indicating if each input is connected to the cost
inputs_connected = [ inputs_connected = [
[ input_to_output and output_to_cost for (True in [ input_to_output and output_to_cost for
input_to_output, output_to_cost in input_to_output, output_to_cost in
zip(input_to_outputs, outputs_connected) ] for zip(input_to_outputs, outputs_connected) ]) for
input_to_outputs in connection_pattern input_to_outputs in connection_pattern
] ]
all_inputs_disconnected = not ( True in inputs_connected ) if True in inputs_connected:
if all_inputs_disconnected:
# At least one input of this op is connected to the cost so we must # At least one input of this op is connected to the cost so we must
# call the op's grad method # call the op's grad method
...@@ -763,7 +762,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -763,7 +762,7 @@ def _populate_grad_dict(var_to_node_to_idx,
node.op, g_r_type, i, r_type) node.op, g_r_type, i, r_type)
#cache the result #cache the result
term_dict[node] = list(input_grads) term_dict[node] = input_grads
return term_dict[node] return term_dict[node]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论