提交 74bac5aa authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made node_to_pattern a module level private function so it can be used

throughout the grad machinery
上级 da41c9da
...@@ -491,40 +491,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -491,40 +491,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
rval, = rval rval, = rval
return rval return rval
def _node_to_pattern(node):
def _populate_var_to_node_to_idx(outputs, wrt): """ given an apply node, obtain its connection pattern
""" this is just a wrapper around Op.connection_pattern
Common code shared between grad and grad_sources_inputs that does type checking and supplies the default value
if the method is not implemented
outputs: a list of variables we want to take gradients of
wrt: a list of variables we want to take the gradient with
respect to.
returns:
var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary.
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply
node's input list
This dictionary will only contain variables that
meet two criteria:
1) The elements of at least one output are a
function of the elements of the variable
2) The elements of the variable are a function
of the elements of at least one member of
wrt
This set is exactly the set of variables that
connect the variables in wrt to the cost being
differentiated.
""" """
def node_to_pattern(node):
# given an apply node, obtain its connection pattern
# this is just a wrapper around Op.connection_pattern
# that does type checking and supplies the default value
# if the method is not implemented
if hasattr(node.op,'connection_pattern'): if hasattr(node.op,'connection_pattern'):
connection_pattern = node.op.connection_pattern(node) connection_pattern = node.op.connection_pattern(node)
...@@ -554,6 +527,35 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -554,6 +527,35 @@ def _populate_var_to_node_to_idx(outputs, wrt):
assert len(connection_pattern[ii]) == \ assert len(connection_pattern[ii]) == \
len(node.outputs) len(node.outputs)
return connection_pattern return connection_pattern
def _populate_var_to_node_to_idx(outputs, wrt):
"""
Common code shared between grad and grad_sources_inputs
outputs: a list of variables we want to take gradients of
wrt: a list of variables we want to take the gradient with
respect to.
returns:
var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary.
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply
node's input list
This dictionary will only contain variables that
meet two criteria:
1) The elements of at least one output are a
function of the elements of the variable
2) The elements of the variable are a function
of the elements of at least one member of
wrt
This set is exactly the set of variables that
connect the variables in wrt to the cost being
differentiated.
"""
# var_to_node_to_idx[var][node] = [i,j] means node has # var_to_node_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j # var as input at positions i and j
var_to_node_to_idx = {} var_to_node_to_idx = {}
...@@ -573,7 +575,7 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -573,7 +575,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
if node not in accounted_for: if node not in accounted_for:
accounted_for.add(node) accounted_for.add(node)
connection_pattern = node_to_pattern(node) connection_pattern = _node_to_pattern(node)
var_idx = node.outputs.index(var) var_idx = node.outputs.index(var)
...@@ -611,7 +613,7 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -611,7 +613,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
visited.add(var) visited.add(var)
nodes = var_to_node_to_idx[var] nodes = var_to_node_to_idx[var]
for node in nodes: for node in nodes:
connection_pattern = node_to_pattern(node) connection_pattern = _node_to_pattern(node)
for idx in nodes[node]: for idx in nodes[node]:
for ii, output in enumerate(node.outputs): for ii, output in enumerate(node.outputs):
if connection_pattern[idx][ii]: if connection_pattern[idx][ii]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论