提交 a3db430f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where a variable appearing twice in an inputs list caused

incorrect gradient
上级 3b546b82
...@@ -466,7 +466,7 @@ def _populate_var_to_node_to_idx(outputs): ...@@ -466,7 +466,7 @@ def _populate_var_to_node_to_idx(outputs):
""" """
#var_to_node_to_idx[var][node] = i means node has var as input at position i #var_to_node_to_idx[var][node] = [i,j] means node has var as input at positions i and j
var_to_node_to_idx = {} var_to_node_to_idx = {}
#set of variables that have been added to their parents #set of variables that have been added to their parents
accounted_for = set([]) accounted_for = set([])
...@@ -480,11 +480,18 @@ def _populate_var_to_node_to_idx(outputs): ...@@ -480,11 +480,18 @@ def _populate_var_to_node_to_idx(outputs):
accounted_for.add(var) accounted_for.add(var)
if var.owner is not None: if var.owner is not None:
node = var.owner node = var.owner
for i, ipt in enumerate(node.inputs): if node not in accounted_for:
if ipt not in var_to_node_to_idx: accounted_for.add(node)
var_to_node_to_idx[ipt] = {} for i, ipt in enumerate(node.inputs):
var_to_node_to_idx[ipt][node] = i if ipt not in var_to_node_to_idx:
account_for(ipt) var_to_node_to_idx[ipt] = {}
node_to_idx = var_to_node_to_idx[ipt]
if node not in node_to_idx:
node_to_idx[node] = []
idx = node_to_idx[node]
assert i not in idx
idx.append(i)
account_for(ipt)
for output in outputs: for output in outputs:
account_for(output) account_for(output)
...@@ -584,19 +591,25 @@ def _populate_grad_dict(var_to_node_to_idx,\ ...@@ -584,19 +591,25 @@ def _populate_grad_dict(var_to_node_to_idx,\
terms = [] terms = []
node_to_idx = var_to_node_to_idx[var] node_to_idx = var_to_node_to_idx[var]
for node in node_to_idx: for node in node_to_idx:
idx = node_to_idx[node] for idx in node_to_idx[node]:
term = access_term_cache(node)[idx]
if hasattr(node.op, 'connection_pattern'):
pattern = node.op.connection_pattern()
if not pattern[idx]:
continue
term = access_term_cache(node)[idx]
if not isinstance(term, gof.Variable): if not isinstance(term, gof.Variable):
raise TypeError("%s.grad returned %s, expected" raise TypeError("%s.grad returned %s, expected"
" Variable instance." % (str(node.op), " Variable instance." % (str(node.op),
type(term))) type(term)))
if isinstance(term.type,NaNType): if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\ raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan) term.type.why_nan)
terms.append( term) terms.append( term)
grad_dict[var] = nonempty_sum(terms) grad_dict[var] = nonempty_sum(terms)
if cost_name is not None and var.name is not None: if cost_name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name) grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论