提交 74bac5aa authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made node_to_pattern a module level private function so it can be used

throughout the grad machinery
上级 da41c9da
......@@ -491,40 +491,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
rval, = rval
return rval
def _populate_var_to_node_to_idx(outputs, wrt):
"""
Common code shared between grad and grad_sources_inputs
outputs: a list of variables we want to take gradients of
wrt: a list of variables we want to take the gradient with
respect to.
returns:
var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary.
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply
node's input list
This dictionary will only contain variables that
meet two criteria:
1) The elements of at least one output are a
function of the elements of the variable
2) The elements of the variable are a function
of the elements of at least one member of
wrt
This set is exactly the set of variables that
connect the variables in wrt to the cost being
differentiated.
def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern
that does type checking and supplies the default value
if the method is not implemented
"""
def node_to_pattern(node):
# given an apply node, obtain its connection pattern
# this is just a wrapper around Op.connection_pattern
# that does type checking and supplies the default value
# if the method is not implemented
if hasattr(node.op,'connection_pattern'):
connection_pattern = node.op.connection_pattern(node)
......@@ -554,6 +527,35 @@ def _populate_var_to_node_to_idx(outputs, wrt):
assert len(connection_pattern[ii]) == \
len(node.outputs)
return connection_pattern
def _populate_var_to_node_to_idx(outputs, wrt):
"""
Common code shared between grad and grad_sources_inputs
outputs: a list of variables we want to take gradients of
wrt: a list of variables we want to take the gradient with
respect to.
returns:
var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary.
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply
node's input list
This dictionary will only contain variables that
meet two criteria:
1) The elements of at least one output are a
function of the elements of the variable
2) The elements of the variable are a function
of the elements of at least one member of
wrt
This set is exactly the set of variables that
connect the variables in wrt to the cost being
differentiated.
"""
# var_to_node_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j
var_to_node_to_idx = {}
......@@ -573,7 +575,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
if node not in accounted_for:
accounted_for.add(node)
connection_pattern = node_to_pattern(node)
connection_pattern = _node_to_pattern(node)
var_idx = node.outputs.index(var)
......@@ -611,7 +613,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
visited.add(var)
nodes = var_to_node_to_idx[var]
for node in nodes:
connection_pattern = node_to_pattern(node)
connection_pattern = _node_to_pattern(node)
for idx in nodes[node]:
for ii, output in enumerate(node.outputs):
if connection_pattern[idx][ii]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论