提交 a3db430f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where a variable appearing twice in an inputs list caused

incorrect gradient
上级 3b546b82
...@@ -466,7 +466,7 @@ def _populate_var_to_node_to_idx(outputs): ...@@ -466,7 +466,7 @@ def _populate_var_to_node_to_idx(outputs):
""" """
#var_to_node_to_idx[var][node] = i means node has var as input at position i #var_to_node_to_idx[var][node] = [i,j] means node has var as input at positions i and j
var_to_node_to_idx = {} var_to_node_to_idx = {}
#set of variables that have been added to their parents #set of variables that have been added to their parents
accounted_for = set([]) accounted_for = set([])
...@@ -480,10 +480,17 @@ def _populate_var_to_node_to_idx(outputs): ...@@ -480,10 +480,17 @@ def _populate_var_to_node_to_idx(outputs):
accounted_for.add(var) accounted_for.add(var)
if var.owner is not None: if var.owner is not None:
node = var.owner node = var.owner
if node not in accounted_for:
accounted_for.add(node)
for i, ipt in enumerate(node.inputs): for i, ipt in enumerate(node.inputs):
if ipt not in var_to_node_to_idx: if ipt not in var_to_node_to_idx:
var_to_node_to_idx[ipt] = {} var_to_node_to_idx[ipt] = {}
var_to_node_to_idx[ipt][node] = i node_to_idx = var_to_node_to_idx[ipt]
if node not in node_to_idx:
node_to_idx[node] = []
idx = node_to_idx[node]
assert i not in idx
idx.append(i)
account_for(ipt) account_for(ipt)
for output in outputs: for output in outputs:
...@@ -584,7 +591,13 @@ def _populate_grad_dict(var_to_node_to_idx,\ ...@@ -584,7 +591,13 @@ def _populate_grad_dict(var_to_node_to_idx,\
terms = [] terms = []
node_to_idx = var_to_node_to_idx[var] node_to_idx = var_to_node_to_idx[var]
for node in node_to_idx: for node in node_to_idx:
idx = node_to_idx[node] for idx in node_to_idx[node]:
if hasattr(node.op, 'connection_pattern'):
pattern = node.op.connection_pattern()
if not pattern[idx]:
continue
term = access_term_cache(node)[idx] term = access_term_cache(node)[idx]
if not isinstance(term, gof.Variable): if not isinstance(term, gof.Variable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论