提交 b8bba33c authored 作者: ChienliMa's avatar ChienliMa

Move scanOP.inner_connection_pattern to gof.graph and reuse it in OpFromGraph

上级 97ab0274
......@@ -6,6 +6,7 @@ from theano.compat import izip
from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function, FunctionGraph
from theano.gof.graph import io_connection_pattern
class OpFromGraph(gof.Op):
......@@ -138,58 +139,9 @@ class OpFromGraph(gof.Op):
def connection_pattern(self, node):
    """
    Return the connection pattern of the subgraph defined by
    `self.new_inputs` and `self.new_outputs`.

    The computation is delegated to `io_connection_pattern`, which
    returns a list with one entry per inner input; each entry is a list
    holding one boolean per inner output, telling whether that output
    depends on that input.

    The previous hand-rolled implementation is dropped because it
    updated the stored pattern of a node's first input in place (via
    `|=` on the array kept in the per-variable map) and built the
    per-output results with an elementwise multiplication.
    """
    # `node` is not consulted: the pattern depends only on the inner
    # graph held by this op.
    return io_connection_pattern(self.new_inputs, self.new_outputs)
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
......
......@@ -862,6 +862,74 @@ default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
", ".join(argstrings))
def io_connection_pattern(inputs, outputs):
    """
    Return the connection pattern of the subgraph defined by the given
    inputs and outputs.

    :param inputs: sequence of variables taken as the subgraph inputs.
    :param outputs: sequence of variables taken as the subgraph outputs.

    :return: a list with one entry per element of `inputs`; each entry
        is a list with one boolean per element of `outputs`.
        `result[i][o]` is True when `outputs[o]` is connected to
        `inputs[i]`.
    """
    inner_nodes = io_toposort(inputs, outputs)
    nb_inputs = len(inputs)

    # Initialize 'connect_pattern_by_var' by establishing each input as
    # connected only to itself
    connect_pattern_by_var = {}
    for i, inp in enumerate(inputs):
        connect_pattern_by_var[inp] = [i == j for j in range(nb_inputs)]

    # Iterate through the nodes used to produce the outputs from the
    # inputs and, for every node, infer their connection pattern to
    # every input from the connection patterns of their parents.
    for n in inner_nodes:
        # Get the connection pattern of the inner node's op. The method
        # is looked up explicitly rather than by catching AttributeError
        # around the call, so that an AttributeError raised *inside* an
        # op's connection_pattern is not mistaken for a missing method.
        op_pattern_fn = getattr(n.op, 'connection_pattern', None)
        if op_pattern_fn is None:
            # No connection_pattern method: assume that every node
            # output is connected to every node input
            op_connection_pattern = [[True] * len(n.outputs)
                                     for _ in n.inputs]
        else:
            op_connection_pattern = op_pattern_fn(n)

        # For every output of the inner node, figure out which inputs it
        # is connected to by combining the connection pattern of the
        # inner node and the connection patterns of the node's inputs.
        for out_idx, out in enumerate(n.outputs):
            out_connection_pattern = [False] * nb_inputs

            for inp_idx, inp in enumerate(n.inputs):
                inp_connection_pattern = connect_pattern_by_var.get(inp)
                if inp_connection_pattern is None:
                    # This node input has no recorded pattern, so it
                    # contributes no connection to the subgraph inputs.
                    continue

                # If the node output is connected to the node input, it
                # is connected to every inner input that the node input
                # is connected to
                if op_connection_pattern[inp_idx][out_idx]:
                    out_connection_pattern = [
                        a or b for a, b in zip(out_connection_pattern,
                                               inp_connection_pattern)]

            # Store the connection pattern of the node output
            connect_pattern_by_var[out] = out_connection_pattern

    # Obtain the global connection pattern by combining the connection
    # patterns of the individual outputs. An output with no recorded
    # pattern (neither one of `inputs` nor produced by an inner node) is
    # reported as connected to no input instead of raising KeyError.
    disconnected = [False] * nb_inputs
    global_connection_pattern = [[] for _ in range(nb_inputs)]
    for out in outputs:
        out_connection_pattern = connect_pattern_by_var.get(out,
                                                            disconnected)
        for i in range(nb_inputs):
            global_connection_pattern[i].append(out_connection_pattern[i])

    return global_connection_pattern
def is_same_graph(var1, var2, givens=None, debug=False):
"""
Return True iff Variables `var1` and `var2` perform the same computation.
......
......@@ -68,7 +68,7 @@ from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano.gof.graph import io_toposort
from theano.gof.graph import io_connection_pattern
from theano.compat import OrderedDict, izip
from theano.tensor import TensorType
from theano.tensor.opt import Shape_i
......@@ -1471,71 +1471,6 @@ class Scan(PureOp):
scan_outs.append((Shape_i(0)(o),) + x[1:])
return scan_outs
def inner_connection_pattern(self):
    """
    Return the connection pattern of scan's inner function.

    The inner function is the subgraph going from `self.inputs` to
    `self.outputs`. The computation is delegated to the shared
    `io_connection_pattern` helper rather than duplicated here; the
    result is a list with one entry per inner input, each entry being a
    list with one boolean per inner output.
    """
    return io_connection_pattern(self.inputs, self.outputs)
def connection_pattern(self, node):
# We cache the result of this function because, with a previous
......@@ -1546,7 +1481,7 @@ class Scan(PureOp):
return node.tag.connection_pattern
# Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern()
inner_connect_pattern = io_connection_pattern(self.inputs, self.outputs)
# Initially assume no outer input is connected to any outer output
connection_pattern = [[False for output in node.outputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论