提交 25706f6c authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3118 from ChienliMa/connection_pattern

OpFromGraph.connection_pattern()
......@@ -4,6 +4,7 @@ from theano.compat import izip
from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern
class OpFromGraph(gof.Op):
......@@ -134,6 +135,12 @@ class OpFromGraph(gof.Op):
# we wont need this copy anymore
output[0] = variable.copy()
def connection_pattern(self, node):
"""
Return connection pattern of subfgraph defined by inputs and outputs
"""
return io_connection_pattern(self.new_inputs, self.new_outputs)
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will
......
......@@ -7,6 +7,7 @@ from theano.compile import function
from theano import tensor
from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraph
......@@ -107,6 +108,48 @@ class T_OpFromGraph(unittest.TestCase):
fn = function([x, y, z], f)
assert numpy.allclose(15.0 + s.get_value(),
fn(xv, yv, zv))
def test_connection_pattern(self):
# Basic case
x, y, z = T.matrices('xyz')
out1 = x * y
out2 = y * z
op1 = OpFromGraph([x ,y, z], [out1, out2], mode='FAST_RUN')
results = op1.connection_pattern(None)
expect_result = [[True, False],
[True, True],
[False, True]]
assert results == expect_result
# Graph with ops that don't have a 'full' connection pattern
# and with ops that have multiple outputs
m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2], mode='FAST_RUN')
results = op2.connection_pattern(None)
expect_result = [[True, False],
[True, True],
[False, True],
[True, True]]
assert results == expect_result
# Inner graph where some computation doesn't rely on explicit inputs
srng = RandomStreams(seed=234)
rv_u = srng.uniform((2,2))
x, y = T.matrices('xy')
out1 = x + rv_u
out2 = y + 3
out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3], mode='FAST_RUN')
results = op3.connection_pattern(None)
expect_result = [[True, False, False],
[False, True, False],
[True, False, True]]
assert results == expect_result
if __name__ == '__main__':
......
......@@ -862,6 +862,74 @@ default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
", ".join(argstrings))
def io_connection_pattern(inputs, outputs):
"""
Returns the connection pattern of a subgraph defined by given
inputs and outputs
"""
inner_nodes = io_toposort(inputs, outputs)
# Initialize 'connect_pattern_by_var' by establishing each input as
# connected only to itself
connect_pattern_by_var = {}
nb_inputs = len(inputs)
nb_outputs = len(outputs)
for i in range(nb_inputs):
input = inputs[i]
inp_connection_pattern = [i == j for j in range(nb_inputs)]
connect_pattern_by_var[input] = inp_connection_pattern
# Iterate through the nodes used to produce the outputs from the
# inputs and, for every node, infer their connection pattern to
# every input from the connection patterns of their parents.
for n in inner_nodes:
# Get the connection pattern of the inner node's op. If the op
# does not define a connection_pattern method, assume that
# every node output is connected to every node input
try:
op_connection_pattern = n.op.connection_pattern(n)
except AttributeError:
op_connection_pattern = ([[True] * len(n.outputs)] *
len(n.inputs))
# For every output of the inner node, figure out which inputs it
# is connected to by combining the connection pattern of the inner
# node and the connection patterns of the inner node's inputs.
for out_idx in range(len(n.outputs)):
out = n.outputs[out_idx]
out_connection_pattern = [False] * nb_inputs
for inp_idx in range(len(n.inputs)):
inp = n.inputs[inp_idx]
if inp in connect_pattern_by_var:
inp_connection_pattern = connect_pattern_by_var[inp]
# If the node output is connected to the node input, it
# means it is connected to every inner input that the
# node inputs is connected to
if op_connection_pattern[inp_idx][out_idx]:
out_connection_pattern = [out_connection_pattern[i] or
inp_connection_pattern[i]
for i in range(nb_inputs)]
# Store the connection pattern of the node output
connect_pattern_by_var[out] = out_connection_pattern
# Obtain the global connection pattern by combining the
# connnection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs:
out_connection_pattern = connect_pattern_by_var[out]
for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
return global_connection_pattern
def is_same_graph(var1, var2, givens=None, debug=False):
"""
Return True iff Variables `var1` and `var2` perform the same computation.
......
......@@ -68,7 +68,7 @@ from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano.gof.graph import io_toposort
from theano.gof.graph import io_connection_pattern
from theano.compat import OrderedDict, izip
from theano.tensor import TensorType
from theano.tensor.opt import Shape_i
......@@ -1471,71 +1471,6 @@ class Scan(PureOp):
scan_outs.append((Shape_i(0)(o),) + x[1:])
return scan_outs
def inner_connection_pattern(self):
""" Returns the connection pattern of scan's inner function
"""
inner_nodes = io_toposort(self.inputs, self.outputs)
# Initialize 'connect_pattern_by_var' by establishing each input as
# connected only to itself
connect_pattern_by_var = {}
nb_inputs = len(self.inputs)
nb_outputs = len(self.outputs)
for i in xrange(nb_inputs):
input = self.inputs[i]
inp_connection_pattern = [i == j for j in xrange(nb_inputs)]
connect_pattern_by_var[input] = inp_connection_pattern
# Iterate through the nodes used to produce the outputs from the
# inputs and, for every node, infer their connection pattern to
# every input from the connection patterns of their parents.
for n in inner_nodes:
# Get the connection pattern of the inner node's op. If the op
# does not define a connection_pattern method, assume that
# every node output is connected to every node input
try:
op_connection_pattern = n.op.connection_pattern(n)
except AttributeError:
op_connection_pattern = ([[True] * len(n.outputs)] *
len(n.inputs))
# For every output of the inner node, figure out which inputs it
# is connected to by combining the connection pattern of the inner
# node and the connection patterns of the inner node's inputs.
for out_idx in xrange(len(n.outputs)):
out = n.outputs[out_idx]
out_connection_pattern = [False] * nb_inputs
for inp_idx in xrange(len(n.inputs)):
inp = n.inputs[inp_idx]
if inp in connect_pattern_by_var:
inp_connection_pattern = connect_pattern_by_var[inp]
# If the node output is connected to the node input, it
# means it is connected to every inner input that the
# node inputs is connected to
if op_connection_pattern[inp_idx][out_idx]:
out_connection_pattern = [out_connection_pattern[i] or
inp_connection_pattern[i]
for i in xrange(nb_inputs)]
# Store the connection pattern of the node output
connect_pattern_by_var[out] = out_connection_pattern
# Obtain the global connection pattern by combining the
# connnection patterns of the individual outputs
global_connection_pattern = [[] for o in xrange(len(self.inputs))]
for out in self.outputs:
out_connection_pattern = connect_pattern_by_var[out]
for i in xrange(len(self.inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
return global_connection_pattern
def connection_pattern(self, node):
# We cache the result of this function because, with a previous
......@@ -1546,7 +1481,7 @@ class Scan(PureOp):
return node.tag.connection_pattern
# Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern()
inner_connect_pattern = io_connection_pattern(self.inputs, self.outputs)
# Initially assume no outer input is connected to any outer output
connection_pattern = [[False for output in node.outputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论