提交 93f3b868 作者: nouiz

Merge pull request #1041 from goodfeli/fix_grad

Fix bug in gradient when some outputs of a node have different connection patterns.
...@@ -602,31 +602,39 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -602,31 +602,39 @@ def _populate_var_to_node_to_idx(outputs, wrt):
respect to. respect to.
returns: returns:
var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary. var_to_app_to_idx:
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply A dictionary mapping a variable to a second dictionary.
node's input list The second dictionary maps apply nodes acting on this
This dictionary will only contain variables that variable to the variable's index in the apply node's
meet two criteria: input list.
1) The elements of at least one output are a
function of the elements of the variable This dictionary will only contain variables that
2) The elements of the variable are a function meet two criteria:
of the elements of at least one member of
wrt 1) The elements of at least one output are a
This set is exactly the set of variables that function of the elements of the variable
connect the variables in wrt to the cost being
differentiated. 2) The elements of the variable are a function of the
elements of at least one member of wrt.
This set is exactly the set of variables that connect
the variables in wrt to the cost being differentiated.
""" """
# var_to_node_to_idx[var][node] = [i,j] means node has # var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j # var as input at positions i and j
var_to_node_to_idx = {} var_to_app_to_idx = {}
# set of variables or nodes that have been added to their true parents
# Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function # ('true' here means that the elements of the variable are a function
# of the elements of the parent, according to the op's # of the elements of the parent, according to the op's
# connection_pattern) # connection_pattern)
# Note: we need to revisit the apply nodes repeatedly, because
# different outputs of the apply node are connected to
# different subsets of the inputs.
accounted_for = set([]) accounted_for = set([])
def account_for(var): def account_for(var):
...@@ -634,30 +642,28 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -634,30 +642,28 @@ def _populate_var_to_node_to_idx(outputs, wrt):
return return
accounted_for.add(var) accounted_for.add(var)
if var.owner is not None: if var.owner is not None:
node = var.owner app = var.owner
if node not in accounted_for:
accounted_for.add(node)
connection_pattern = _node_to_pattern(node) connection_pattern = _node_to_pattern(app)
var_idx = node.outputs.index(var) var_idx = app.outputs.index(var)
for i, ipt in enumerate(node.inputs): for i, ipt in enumerate(app.inputs):
#don't process ipt if it is not a true #don't process ipt if it is not a true
#parent of var #parent of var
if not connection_pattern[i][var_idx]: if not connection_pattern[i][var_idx]:
continue continue
if ipt not in var_to_node_to_idx: if ipt not in var_to_app_to_idx:
var_to_node_to_idx[ipt] = {} var_to_app_to_idx[ipt] = {}
node_to_idx = var_to_node_to_idx[ipt] app_to_idx = var_to_app_to_idx[ipt]
if node not in node_to_idx: if app not in app_to_idx:
node_to_idx[node] = [] app_to_idx[app] = []
idx = node_to_idx[node] idx = app_to_idx[app]
assert i not in idx if i not in idx:
idx.append(i) idx.append(i)
account_for(ipt) account_for(ipt)
# add all variables that are true ancestors of the cost # add all variables that are true ancestors of the cost
for output in outputs: for output in outputs:
...@@ -671,10 +677,10 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -671,10 +677,10 @@ def _populate_var_to_node_to_idx(outputs, wrt):
def visit(var): def visit(var):
if var in visited: if var in visited:
return return
if var not in var_to_node_to_idx: if var not in var_to_app_to_idx:
return return
visited.add(var) visited.add(var)
nodes = var_to_node_to_idx[var] nodes = var_to_app_to_idx[var]
for node in nodes: for node in nodes:
connection_pattern = _node_to_pattern(node) connection_pattern = _node_to_pattern(node)
for idx in nodes[node]: for idx in nodes[node]:
...@@ -686,12 +692,12 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -686,12 +692,12 @@ def _populate_var_to_node_to_idx(outputs, wrt):
visit(elem) visit(elem)
# Remove variables that don't have wrt as a true ancestor # Remove variables that don't have wrt as a true ancestor
orig_vars = list(var_to_node_to_idx.keys()) orig_vars = list(var_to_app_to_idx.keys())
for var in orig_vars: for var in orig_vars:
if var not in visited: if var not in visited:
del var_to_node_to_idx[var] del var_to_app_to_idx[var]
return var_to_node_to_idx return var_to_app_to_idx
def _populate_grad_dict(var_to_node_to_idx, def _populate_grad_dict(var_to_node_to_idx,
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论