提交 46ebe84a authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Method for getting the connection_pattern of scan's inner function

上级 39b4ec6e
......@@ -24,6 +24,7 @@ from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano.gof.graph import io_toposort
from theano.compat.python2x import any, OrderedDict
from theano.tensor import TensorType
from theano.tensor.opt import Shape_i
......@@ -1305,6 +1306,71 @@ class Scan(PureOp):
ipos += len(otaps)
return ipos + opos
def inner_connection_pattern(self, node):
    """Return the connection pattern of scan's inner function.

    The pattern is computed from ``self.inputs`` and ``self.outputs``
    only; ``node`` is accepted but not used here.

    :param node: unused by this method.

    :return: a list with one entry per inner input (same order as
        ``self.inputs``).  Each entry is a list with one value per
        inner output (same order as ``self.outputs``);
        ``result[i][j]`` is true when inner output ``j`` is connected
        to inner input ``i`` through the inner graph.
    """
    # NOTE(review): the propagation below relies on io_toposort
    # returning the nodes so that a node comes after the nodes producing
    # its inputs -- confirm against io_toposort's documentation.
    inner_nodes = io_toposort(self.inputs, self.outputs)
    nb_inputs = len(self.inputs)

    # Initialize 'connect_pattern_by_var' by establishing each input as
    # connected only to itself
    connect_pattern_by_var = {}
    for i, inner_inp in enumerate(self.inputs):
        connect_pattern_by_var[inner_inp] = [i == j
                                             for j in range(nb_inputs)]

    # Iterate through the nodes used to produce the outputs from the
    # inputs and, for every node, infer their connection pattern to
    # every input from the connection patterns of their parents.
    for n in inner_nodes:
        # Get the connection pattern of the inner node's op. If the op
        # does not define a connection_pattern method, assume that
        # every node output is connected to every node input.
        # The rows of the fallback alias one another; this is fine
        # because the pattern is only ever read.
        try:
            op_connection_pattern = n.op.connection_pattern(n)
        except AttributeError:
            op_connection_pattern = ([[True] * len(n.outputs)] *
                                     len(n.inputs))

        # For every output of the inner node, figure out which inputs it
        # is connected to by combining the connection pattern of the
        # inner node and the connection patterns of the inner node's
        # inputs.
        for out_idx, out in enumerate(n.outputs):
            out_connection_pattern = [False] * nb_inputs

            for inp_idx, inp in enumerate(n.inputs):
                # Node inputs with no recorded pattern (not an inner
                # input and not produced by a visited node) contribute
                # no connection.  If the node output is connected to the
                # node input, it is connected to every inner input that
                # the node input is connected to.
                if (inp in connect_pattern_by_var and
                        op_connection_pattern[inp_idx][out_idx]):
                    inp_connection_pattern = connect_pattern_by_var[inp]
                    out_connection_pattern = [
                        out_conn or inp_conn
                        for out_conn, inp_conn in zip(
                            out_connection_pattern,
                            inp_connection_pattern)]

            # Store the connection pattern of the node output
            connect_pattern_by_var[out] = out_connection_pattern

    # Obtain the global connection pattern by combining the
    # connection patterns of the individual outputs
    global_connection_pattern = [[] for _ in range(nb_inputs)]
    # Pattern used for an output that is neither an inner input nor
    # produced by any visited node: it is isolated from every input.
    # Without this fallback such an output raised a KeyError.
    isolated_pattern = [False] * nb_inputs
    for out in self.outputs:
        out_connection_pattern = connect_pattern_by_var.get(
            out, isolated_pattern)
        for i in range(nb_inputs):
            global_connection_pattern[i].append(out_connection_pattern[i])

    return global_connection_pattern
def connection_pattern(self, node):
# We cache this, as grad calls connection_pattern, and it calls
# grad in its turn. I saw a case where theano.grad() took 4h
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论