提交 4cbfea77 authored 作者: ChienliMa's avatar ChienliMa

resotre clone after utils.infer_shape()

上级 124cffee
...@@ -142,14 +142,13 @@ class OpFromGraph(gof.Op): ...@@ -142,14 +142,13 @@ class OpFromGraph(gof.Op):
return io_connection_pattern(self.new_inputs, self.new_outputs) return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
shape = theano.scan_module.scan_utils.infer_shape(self.new_outputs, out_shp = theano.scan_module.scan_utils.infer_shape(self.new_outputs,
self.new_inputs, self.new_inputs,
shapes) shapes)
replacement = dict([(ori, rpl) for ori, rpl
in izip(self.new_inputs, node.inputs)])
import pdb return [theano.clone(shape, replace=replacement) for shape in out_shp]
pdb.set_trace()
return shape
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for # OpFromGraph doesn't implement a connection_pattern, so for
...@@ -183,5 +182,3 @@ class OpFromGraph(gof.Op): ...@@ -183,5 +182,3 @@ class OpFromGraph(gof.Op):
# Since OpFromGraph contains a Theano compiled function, we should let # Since OpFromGraph contains a Theano compiled function, we should let
# DebugMode know about it # DebugMode know about it
ops_with_inner_function[OpFromGraph] = 'fn' ops_with_inner_function[OpFromGraph] = 'fn'
...@@ -12,146 +12,145 @@ from theano.compile.builders import OpFromGraph ...@@ -12,146 +12,145 @@ from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools from theano.tests import unittest_tools
import unittest
class T_OpFromGraph(unittest_tools.InferShapeTester): class T_OpFromGraph(unittest_tools.InferShapeTester):
# def test_straightforward(self): def test_straightforward(self):
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# e = x + y * z e = x + y * z
# op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
# # (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
# f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# xv = numpy.ones((2, 2), dtype=config.floatX) xv = numpy.ones((2, 2), dtype=config.floatX)
# yv = numpy.ones((2, 2), dtype=config.floatX)*3 yv = numpy.ones((2, 2), dtype=config.floatX)*3
# zv = numpy.ones((2, 2), dtype=config.floatX)*5 zv = numpy.ones((2, 2), dtype=config.floatX)*5
# # print function, function.__module__ # print function, function.__module__
# # print fn.maker.fgraph.toposort() # print fn.maker.fgraph.toposort()
# fn(xv, yv, zv) fn(xv, yv, zv)
# assert numpy.all(8.0 == fn(xv, yv, zv)) assert numpy.all(8.0 == fn(xv, yv, zv))
# assert numpy.all(8.0 == fn(xv, yv, zv)) assert numpy.all(8.0 == fn(xv, yv, zv))
# def test_size_changes(self): def test_size_changes(self):
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# e = T.dot(x, y) e = T.dot(x, y)
# op = OpFromGraph([x, y], [e]) op = OpFromGraph([x, y], [e])
# f = op(x, op(y, z)) f = op(x, op(y, z))
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# xv = numpy.ones((2, 3), dtype=config.floatX) xv = numpy.ones((2, 3), dtype=config.floatX)
# yv = numpy.ones((3, 4), dtype=config.floatX)*3 yv = numpy.ones((3, 4), dtype=config.floatX)*3
# zv = numpy.ones((4, 5), dtype=config.floatX)*5 zv = numpy.ones((4, 5), dtype=config.floatX)*5
# res = fn(xv, yv, zv) res = fn(xv, yv, zv)
# assert res.shape == (2, 5) assert res.shape == (2, 5)
# assert numpy.all(180.0 == res) assert numpy.all(180.0 == res)
# res = fn(xv, yv, zv) res = fn(xv, yv, zv)
# assert res.shape == (2, 5) assert res.shape == (2, 5)
# assert numpy.all(180.0 == res) assert numpy.all(180.0 == res)
# def test_grad(self): def test_grad(self):
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# e = x + y * z e = x + y * z
# op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
# f = op(x, y, z) f = op(x, y, z)
# f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# xv = numpy.ones((2, 2), dtype=config.floatX) xv = numpy.ones((2, 2), dtype=config.floatX)
# yv = numpy.ones((2, 2), dtype=config.floatX)*3 yv = numpy.ones((2, 2), dtype=config.floatX)*3
# zv = numpy.ones((2, 2), dtype=config.floatX)*5 zv = numpy.ones((2, 2), dtype=config.floatX)*5
# assert numpy.all(11.0 == fn(xv, yv, zv)) assert numpy.all(11.0 == fn(xv, yv, zv))
# def test_grad_grad(self): def test_grad_grad(self):
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# e = x + y * z e = x + y * z
# op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
# f = op(x, y, z) f = op(x, y, z)
# f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
# f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# xv = numpy.ones((2, 2), dtype=config.floatX) xv = numpy.ones((2, 2), dtype=config.floatX)
# yv = numpy.ones((2, 2), dtype=config.floatX)*3 yv = numpy.ones((2, 2), dtype=config.floatX)*3
# zv = numpy.ones((2, 2), dtype=config.floatX)*5 zv = numpy.ones((2, 2), dtype=config.floatX)*5
# assert numpy.allclose(6.0, fn(xv, yv, zv)) assert numpy.allclose(6.0, fn(xv, yv, zv))
# def test_shared(self): def test_shared(self):
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(numpy.random.rand(2, 2).astype(config.floatX))
# e = x + y * z + s e = x + y * z + s
# op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
# # (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
# f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# xv = numpy.ones((2, 2), dtype=config.floatX) xv = numpy.ones((2, 2), dtype=config.floatX)
# yv = numpy.ones((2, 2), dtype=config.floatX)*3 yv = numpy.ones((2, 2), dtype=config.floatX)*3
# zv = numpy.ones((2, 2), dtype=config.floatX)*5 zv = numpy.ones((2, 2), dtype=config.floatX)*5
# # print function, function.__module__ # print function, function.__module__
# # print fn.maker.fgraph.toposort() # print fn.maker.fgraph.toposort()
# assert numpy.allclose(8.0, fn(xv, yv, zv)) assert numpy.allclose(8.0, fn(xv, yv, zv))
# assert numpy.allclose(8.0, fn(xv, yv, zv)) assert numpy.allclose(8.0, fn(xv, yv, zv))
# def test_shared_grad(self): def test_shared_grad(self):
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(numpy.random.rand(2, 2).astype(config.floatX))
# e = x + y * z + s e = x + y * z + s
# op = OpFromGraph([x, y, z], [e]) op = OpFromGraph([x, y, z], [e])
# f = op(x, y, z) f = op(x, y, z)
# f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# xv = numpy.ones((2, 2), dtype=config.floatX) xv = numpy.ones((2, 2), dtype=config.floatX)
# yv = numpy.ones((2, 2), dtype=config.floatX) * 3 yv = numpy.ones((2, 2), dtype=config.floatX) * 3
# zv = numpy.ones((2, 2), dtype=config.floatX) * 5 zv = numpy.ones((2, 2), dtype=config.floatX) * 5
# assert numpy.allclose(11.0 + s.get_value(), fn(xv, yv, zv)) assert numpy.allclose(11.0 + s.get_value(), fn(xv, yv, zv))
# # grad again the shared variable # grad again the shared variable
# f = op(x, y, z) f = op(x, y, z)
# f = f - T.grad(T.sum(f), s) f = f - T.grad(T.sum(f), s)
# fn = function([x, y, z], f) fn = function([x, y, z], f)
# assert numpy.allclose(15.0 + s.get_value(), assert numpy.allclose(15.0 + s.get_value(),
# fn(xv, yv, zv)) fn(xv, yv, zv))
# def test_connection_pattern(self): def test_connection_pattern(self):
# # Basic case # Basic case
# x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
# out1 = x * y out1 = x * y
# out2 = y * z out2 = y * z
# op1 = OpFromGraph([x ,y, z], [out1, out2]) op1 = OpFromGraph([x ,y, z], [out1, out2])
# results = op1.connection_pattern(None) results = op1.connection_pattern(None)
# expect_result = [[True, False], expect_result = [[True, False],
# [True, True], [True, True],
# [False, True]] [False, True]]
# assert results == expect_result assert results == expect_result
# # Graph with ops that don't have a 'full' connection pattern # Graph with ops that don't have a 'full' connection pattern
# # and with ops that have multiple outputs # and with ops that have multiple outputs
# m, n, p, q = T.matrices('mnpq') m, n, p, q = T.matrices('mnpq')
# o1, o2 = op1(m, n, p) o1, o2 = op1(m, n, p)
# out1, out2 = op1(o1, q, o2) out1, out2 = op1(o1, q, o2)
# op2 = OpFromGraph([m, n, p, q], [out1, out2]) op2 = OpFromGraph([m, n, p, q], [out1, out2])
# results = op2.connection_pattern(None) results = op2.connection_pattern(None)
# expect_result = [[True, False], expect_result = [[True, False],
# [True, True], [True, True],
# [False, True], [False, True],
# [True, True]] [True, True]]
# assert results == expect_result assert results == expect_result
# # Inner graph where some computation doesn't rely on explicit inputs # Inner graph where some computation doesn't rely on explicit inputs
# srng = RandomStreams(seed=234) srng = RandomStreams(seed=234)
# rv_u = srng.uniform((2,2)) rv_u = srng.uniform((2,2))
# x, y = T.matrices('xy') x, y = T.matrices('xy')
# out1 = x + rv_u out1 = x + rv_u
# out2 = y + 3 out2 = y + 3
# out3 = 3 + rv_u out3 = 3 + rv_u
# op3 = OpFromGraph([x, y], [out1, out2, out3]) op3 = OpFromGraph([x, y], [out1, out2, out3])
# results = op3.connection_pattern(None) results = op3.connection_pattern(None)
# expect_result = [[True, False, False], expect_result = [[True, False, False],
# [False, True, False], [False, True, False],
# [True, False, True]] [True, False, True]]
# assert results == expect_result assert results == expect_result
def test_infer_shape(self): def test_infer_shape(self):
x = T.matrix('x') x = T.matrix('x')
...@@ -164,9 +163,6 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -164,9 +163,6 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
p = T.matrix('p') p = T.matrix('p')
self._compile_and_check([q,p], self._compile_and_check([q,p],
[op_graph(q,p)[0],op_graph(q,p)[1]], [op_graph(q,p)[0],op_graph(q,p)[1]],
[numpy.ones([3,4], dtype=config.floatX)], [numpy.ones([3,4], dtype=config.floatX),
numpy.ones([3,4], dtype=config.floatX)],
OpFromGraph) OpFromGraph)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论