提交 2ddaca06 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3117 from ChienliMa/infer_shape

OpFromGraph.infer_shape()
......@@ -6,6 +6,8 @@ from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern
from functools import reduce
class OpFromGraph(gof.Op):
"""This creates an `Op` from inputs and outputs lists of variables.
......@@ -141,6 +143,29 @@ class OpFromGraph(gof.Op):
"""
return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
    """Return the output shapes of this op, expressed on the outer inputs.

    :param node: the Apply node of this op; its ``inputs`` are substituted
        for ``self.new_inputs`` in the inferred shape graphs.
    :param shapes: shapes of the inputs, forwarded unchanged to
        ``scan_utils.infer_shape`` (presumably one entry per input --
        confirm against the caller).

    :return: a list with one entry per inner output shape; each entry is
        the slice of cloned shape variables belonging to that output.
    """
    inner_shapes = theano.scan_module.scan_utils.infer_shape(
        self.new_outputs, self.new_inputs, shapes)

    # Clone the output shapes so that they are computed from the outer
    # inputs.
    # Note:
    # This could be written more simply as
    #     ret = [theano.clone(shp, replace=repl) for shp in out_shp]
    # but cloning several times could duplicate subgraphs shared between
    # the individual shapes. The Theano optimizer would clean this up
    # later, at the cost of extra work for the optimizer. So all shape
    # entries are flattened, cloned in a single call, and split again.
    inner_to_outer = dict(zip(self.new_inputs, node.inputs))
    flat_shapes = reduce(tuple.__add__, inner_shapes)
    cloned = theano.clone(flat_shapes, replace=inner_to_outer)

    # Cut the flat cloned list back into one chunk per output shape.
    result = []
    start = 0
    for shp in inner_shapes:
        stop = start + len(shp)
        result.append(cloned[start:stop])
        start = stop
    return result
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will
......
import numpy
import unittest
from theano import config, shared
......@@ -11,13 +10,15 @@ from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools
class T_OpFromGraph(unittest.TestCase):
class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_straightforward(self):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
op = OpFromGraph([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x)
......@@ -34,7 +35,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_size_changes(self):
x, y, z = T.matrices('xyz')
e = T.dot(x, y)
op = OpFromGraph([x, y], [e], mode='FAST_RUN')
op = OpFromGraph([x, y], [e])
f = op(x, op(y, z))
fn = function([x, y, z], f)
xv = numpy.ones((2, 3), dtype=config.floatX)
......@@ -50,7 +51,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_grad(self):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
op = OpFromGraph([x, y, z], [e])
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f)
......@@ -62,7 +63,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_grad_grad(self):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
op = OpFromGraph([x, y, z], [e])
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
f = f - T.grad(T.sum(f), y)
......@@ -76,7 +77,7 @@ class T_OpFromGraph(unittest.TestCase):
x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
op = OpFromGraph([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x)
......@@ -93,7 +94,7 @@ class T_OpFromGraph(unittest.TestCase):
x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
op = OpFromGraph([x, y, z], [e])
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f)
......@@ -115,7 +116,7 @@ class T_OpFromGraph(unittest.TestCase):
out1 = x * y
out2 = y * z
op1 = OpFromGraph([x ,y, z], [out1, out2], mode='FAST_RUN')
op1 = OpFromGraph([x ,y, z], [out1, out2])
results = op1.connection_pattern(None)
expect_result = [[True, False],
[True, True],
......@@ -127,7 +128,7 @@ class T_OpFromGraph(unittest.TestCase):
m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2], mode='FAST_RUN')
op2 = OpFromGraph([m, n, p, q], [out1, out2])
results = op2.connection_pattern(None)
expect_result = [[True, False],
......@@ -143,7 +144,7 @@ class T_OpFromGraph(unittest.TestCase):
out1 = x + rv_u
out2 = y + 3
out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3], mode='FAST_RUN')
op3 = OpFromGraph([x, y], [out1, out2, out3])
results = op3.connection_pattern(None)
expect_result = [[True, False, False],
......@@ -151,6 +152,17 @@ class T_OpFromGraph(unittest.TestCase):
[True, False, True]]
assert results == expect_result
if __name__ == '__main__':
unittest.main()
def test_infer_shape(self):
    """Check shape inference of an OpFromGraph with two outputs."""
    # Inner graph: two elementwise outputs built from the same two
    # matrix inputs.
    x = T.matrix('x')
    y = T.matrix('y')
    o1 = x + y
    o2 = x * y
    op_graph = OpFromGraph([x, y], [o1, o2])

    # Outer graph: apply the op to fresh matrix variables.
    q = T.matrix('q')
    p = T.matrix('p')
    outer_inputs = [q, p]
    outer_outputs = op_graph(q, p)

    # Two separate 3x4 arrays of ones, one per outer input.
    input_values = [
        numpy.ones([3, 4], dtype=config.floatX),
        numpy.ones([3, 4], dtype=config.floatX),
    ]

    self._compile_and_check(outer_inputs,
                            outer_outputs,
                            input_values,
                            OpFromGraph)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论