提交 74bac5aa authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made node_to_pattern a module level private function so it can be used

throughout the grad machinery
上级 da41c9da
...@@ -491,6 +491,42 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -491,6 +491,42 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
rval, = rval rval, = rval
return rval return rval
def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern
that does type checking and supplies the default value
if the method is not implemented
"""
if hasattr(node.op,'connection_pattern'):
connection_pattern = node.op.connection_pattern(node)
if not isinstance(connection_pattern, list):
raise TypeError("Op.connection_pattern should return " + \
("list of list of bool, but for Op=%s" % node.op) +\
"got %s with type %s." % (connection_pattern,
type(connection_pattern)))
if len(connection_pattern) != len(node.inputs):
raise ValueError('%s.connection_pattern should have %d' %
(node.op, len(node.inputs)) + 'rows but has %d.' %
len(connection_pattern))
for ii, output_pattern in enumerate(connection_pattern):
if not isinstance(output_pattern, list):
raise TypeError('%s.connection_pattern should return' %
node.op + ' a list of lists, but element %d' % ii\
+ 'is %s of type %s.' % (output_pattern,
type(output_pattern)))
else:
connection_pattern = \
[[True for output in node.outputs]
for ipt in node.inputs]
assert isinstance(connection_pattern,list)
assert len(connection_pattern) == len(node.inputs)
for ii in xrange(len(node.inputs)):
assert isinstance(connection_pattern[ii], list)
assert len(connection_pattern[ii]) == \
len(node.outputs)
return connection_pattern
def _populate_var_to_node_to_idx(outputs, wrt): def _populate_var_to_node_to_idx(outputs, wrt):
""" """
...@@ -520,40 +556,6 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -520,40 +556,6 @@ def _populate_var_to_node_to_idx(outputs, wrt):
""" """
def node_to_pattern(node):
# given an apply node, obtain its connection pattern
# this is just a wrapper around Op.connection_pattern
# that does type checking and supplies the default value
# if the method is not implemented
if hasattr(node.op,'connection_pattern'):
connection_pattern = node.op.connection_pattern(node)
if not isinstance(connection_pattern, list):
raise TypeError("Op.connection_pattern should return " + \
("list of list of bool, but for Op=%s" % node.op) +\
"got %s with type %s." % (connection_pattern,
type(connection_pattern)))
if len(connection_pattern) != len(node.inputs):
raise ValueError('%s.connection_pattern should have %d' %
(node.op, len(node.inputs)) + 'rows but has %d.' %
len(connection_pattern))
for ii, output_pattern in enumerate(connection_pattern):
if not isinstance(output_pattern, list):
raise TypeError('%s.connection_pattern should return' %
node.op + ' a list of lists, but element %d' % ii\
+ 'is %s of type %s.' % (output_pattern,
type(output_pattern)))
else:
connection_pattern = \
[[True for output in node.outputs]
for ipt in node.inputs]
assert isinstance(connection_pattern,list)
assert len(connection_pattern) == len(node.inputs)
for ii in xrange(len(node.inputs)):
assert isinstance(connection_pattern[ii], list)
assert len(connection_pattern[ii]) == \
len(node.outputs)
return connection_pattern
# var_to_node_to_idx[var][node] = [i,j] means node has # var_to_node_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j # var as input at positions i and j
var_to_node_to_idx = {} var_to_node_to_idx = {}
...@@ -573,7 +575,7 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -573,7 +575,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
if node not in accounted_for: if node not in accounted_for:
accounted_for.add(node) accounted_for.add(node)
connection_pattern = node_to_pattern(node) connection_pattern = _node_to_pattern(node)
var_idx = node.outputs.index(var) var_idx = node.outputs.index(var)
...@@ -611,7 +613,7 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -611,7 +613,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
visited.add(var) visited.add(var)
nodes = var_to_node_to_idx[var] nodes = var_to_node_to_idx[var]
for node in nodes: for node in nodes:
connection_pattern = node_to_pattern(node) connection_pattern = _node_to_pattern(node)
for idx in nodes[node]: for idx in nodes[node]:
for ii, output in enumerate(node.outputs): for ii, output in enumerate(node.outputs):
if connection_pattern[idx][ii]: if connection_pattern[idx][ii]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论