提交 2484dde0 作者: Ian Goodfellow

fix bug in gradient when some outputs of a node have different connection patterns

(Razvan's branch will introduce a test that catches this bug)
上级 5d568c17
......@@ -602,31 +602,39 @@ def _populate_var_to_node_to_idx(outputs, wrt):
respect to.
returns:
var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary.
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply
node's input list
This dictionary will only contain variables that
meet two criteria:
1) The elements of at least one output are a
function of the elements of the variable
2) The elements of the variable are a function
of the elements of at least one member of
wrt
This set is exactly the set of variables that
connect the variables in wrt to the cost being
differentiated.
var_to_app_to_idx:
A dictionary mapping a variable to a second dictionary.
The second dictionary maps apply nodes acting on this
variable to the variable's index in the apply node's
input list.
This dictionary will only contain variables that
meet two criteria:
1) The elements of at least one output are a
function of the elements of the variable
2) The elements of the variable are a function of the
elements of at least one member of wrt.
This set is exactly the set of variables that connect
the variables in wrt to the cost being differentiated.
"""
# var_to_app_to_idx[var][app] = [i,j] means app has
# var as input at positions i and j
var_to_node_to_idx = {}
# set of variables or nodes that have been added to their true parents
var_to_app_to_idx = {}
# Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function
# of the elements of the parent, according to the op's
# connection_pattern)
# Note: we need to revisit the apply nodes repeatedly, because
# different outputs of the apply node are connected to
# different subsets of the inputs.
accounted_for = set([])
def account_for(var):
......@@ -634,30 +642,28 @@ def _populate_var_to_node_to_idx(outputs, wrt):
return
accounted_for.add(var)
if var.owner is not None:
node = var.owner
if node not in accounted_for:
accounted_for.add(node)
app = var.owner
connection_pattern = _node_to_pattern(node)
connection_pattern = _node_to_pattern(app)
var_idx = node.outputs.index(var)
var_idx = app.outputs.index(var)
for i, ipt in enumerate(node.inputs):
for i, ipt in enumerate(app.inputs):
#don't process ipt if it is not a true
#parent of var
if not connection_pattern[i][var_idx]:
continue
#don't process ipt if it is not a true
#parent of var
if not connection_pattern[i][var_idx]:
continue
if ipt not in var_to_node_to_idx:
var_to_node_to_idx[ipt] = {}
node_to_idx = var_to_node_to_idx[ipt]
if node not in node_to_idx:
node_to_idx[node] = []
idx = node_to_idx[node]
assert i not in idx
idx.append(i)
account_for(ipt)
if ipt not in var_to_app_to_idx:
var_to_app_to_idx[ipt] = {}
app_to_idx = var_to_app_to_idx[ipt]
if app not in app_to_idx:
app_to_idx[app] = []
idx = app_to_idx[app]
assert i not in idx
idx.append(i)
account_for(ipt)
# add all variables that are true ancestors of the cost
for output in outputs:
......@@ -671,10 +677,10 @@ def _populate_var_to_node_to_idx(outputs, wrt):
def visit(var):
if var in visited:
return
if var not in var_to_node_to_idx:
if var not in var_to_app_to_idx:
return
visited.add(var)
nodes = var_to_node_to_idx[var]
nodes = var_to_app_to_idx[var]
for node in nodes:
connection_pattern = _node_to_pattern(node)
for idx in nodes[node]:
......@@ -686,12 +692,12 @@ def _populate_var_to_node_to_idx(outputs, wrt):
visit(elem)
# Remove variables that don't have wrt as a true ancestor
orig_vars = list(var_to_node_to_idx.keys())
orig_vars = list(var_to_app_to_idx.keys())
for var in orig_vars:
if var not in visited:
del var_to_node_to_idx[var]
del var_to_app_to_idx[var]
return var_to_node_to_idx
return var_to_app_to_idx
def _populate_grad_dict(var_to_node_to_idx,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论