提交 2ddaca06 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3117 from ChienliMa/infer_shape

OpFromGraph.infer_shape()
...@@ -6,6 +6,8 @@ from theano.compile import SharedVariable, rebuild_collect_shared ...@@ -6,6 +6,8 @@ from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
from functools import reduce
class OpFromGraph(gof.Op): class OpFromGraph(gof.Op):
"""This creates an `Op` from inputs and outputs lists of variables. """This creates an `Op` from inputs and outputs lists of variables.
...@@ -141,6 +143,29 @@ class OpFromGraph(gof.Op): ...@@ -141,6 +143,29 @@ class OpFromGraph(gof.Op):
""" """
return io_connection_pattern(self.new_inputs, self.new_outputs) return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
    """Return the symbolic shapes of this Op's outputs.

    Parameters
    ----------
    node : Apply
        The apply node of this Op in the outer graph.
    shapes : list of tuples
        One tuple of symbolic shape entries per outer input.

    Returns
    -------
    list
        One sequence of symbolic shape entries per output, expressed
        in terms of the outer inputs of ``node``.
    """
    # Infer the shapes of the inner outputs from the inner inputs.
    inner_shapes = theano.scan_module.scan_utils.infer_shape(
        self.new_outputs, self.new_inputs, shapes)

    # Re-express those shapes in terms of the outer inputs.  We clone
    # all shape graphs in one single call rather than once per output:
    # cloning them separately would duplicate any subgraph shared
    # between the shapes of several outputs, and the optimizer would
    # then have to do extra work to merge the duplicates back.
    replacements = dict(zip(self.new_inputs, node.inputs))
    flat_shapes = reduce(tuple.__add__, inner_shapes)
    flat_cloned = theano.clone(flat_shapes, replace=replacements)

    # Split the flat cloned sequence back into one chunk per output.
    result = []
    offset = 0
    for shp in inner_shapes:
        result.append(flat_cloned[offset:offset + len(shp)])
        offset += len(shp)
    return result
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for # OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will # now we regard all inputs and outputs as connected. This will
......
import numpy import numpy
import unittest
from theano import config, shared from theano import config, shared
...@@ -11,13 +10,15 @@ from theano.tensor.shared_randomstreams import RandomStreams ...@@ -11,13 +10,15 @@ from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraph from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools
class T_OpFromGraph(unittest.TestCase):
class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_straightforward(self): def test_straightforward(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN') op = OpFromGraph([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
...@@ -34,7 +35,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -34,7 +35,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_size_changes(self): def test_size_changes(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = T.dot(x, y) e = T.dot(x, y)
op = OpFromGraph([x, y], [e], mode='FAST_RUN') op = OpFromGraph([x, y], [e])
f = op(x, op(y, z)) f = op(x, op(y, z))
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = numpy.ones((2, 3), dtype=config.floatX) xv = numpy.ones((2, 3), dtype=config.floatX)
...@@ -50,7 +51,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -50,7 +51,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_grad(self): def test_grad(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN') op = OpFromGraph([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -62,7 +63,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -62,7 +63,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_grad_grad(self): def test_grad_grad(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN') op = OpFromGraph([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
...@@ -76,7 +77,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -76,7 +77,7 @@ class T_OpFromGraph(unittest.TestCase):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN') op = OpFromGraph([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
...@@ -93,7 +94,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -93,7 +94,7 @@ class T_OpFromGraph(unittest.TestCase):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN') op = OpFromGraph([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -115,7 +116,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -115,7 +116,7 @@ class T_OpFromGraph(unittest.TestCase):
out1 = x * y out1 = x * y
out2 = y * z out2 = y * z
op1 = OpFromGraph([x ,y, z], [out1, out2], mode='FAST_RUN') op1 = OpFromGraph([x ,y, z], [out1, out2])
results = op1.connection_pattern(None) results = op1.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
[True, True], [True, True],
...@@ -127,7 +128,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -127,7 +128,7 @@ class T_OpFromGraph(unittest.TestCase):
m, n, p, q = T.matrices('mnpq') m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p) o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2) out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2], mode='FAST_RUN') op2 = OpFromGraph([m, n, p, q], [out1, out2])
results = op2.connection_pattern(None) results = op2.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
...@@ -143,7 +144,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -143,7 +144,7 @@ class T_OpFromGraph(unittest.TestCase):
out1 = x + rv_u out1 = x + rv_u
out2 = y + 3 out2 = y + 3
out3 = 3 + rv_u out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3], mode='FAST_RUN') op3 = OpFromGraph([x, y], [out1, out2, out3])
results = op3.connection_pattern(None) results = op3.connection_pattern(None)
expect_result = [[True, False, False], expect_result = [[True, False, False],
...@@ -151,6 +152,17 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -151,6 +152,17 @@ class T_OpFromGraph(unittest.TestCase):
[True, False, True]] [True, False, True]]
assert results == expect_result assert results == expect_result
def test_infer_shape(self):
    """Check OpFromGraph.infer_shape via the InferShapeTester harness.

    Builds an OpFromGraph with two element-wise outputs (same shape as
    the inputs) and verifies that the shapes inferred symbolically
    match the shapes of the values actually computed.
    """
    x = T.matrix('x')
    y = T.matrix('y')
    o1 = x + y
    o2 = x * y
    op_graph = OpFromGraph([x, y], [o1, o2])

    # Fresh outer variables, distinct from the inner graph's inputs,
    # so that infer_shape must map inner shapes onto outer inputs.
    q = T.matrix('q')
    p = T.matrix('p')
    # _compile_and_check (from unittest_tools.InferShapeTester) compiles
    # the outputs together with their inferred shapes and asserts they
    # agree on the given numeric inputs.
    self._compile_and_check([q, p],
                            op_graph(q, p),
                            [numpy.ones([3, 4], dtype=config.floatX),
                             numpy.ones([3, 4], dtype=config.floatX)],
                            OpFromGraph)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论