提交 9b635db7 authored 作者: ChienliMa's avatar ChienliMa

add other test case

上级 12b09e99
...@@ -7,6 +7,7 @@ from theano.compile import function ...@@ -7,6 +7,7 @@ from theano.compile import function
from theano import tensor from theano import tensor
from theano import tensor as T from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraph from theano.compile.builders import OpFromGraph
...@@ -109,17 +110,46 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -109,17 +110,46 @@ class T_OpFromGraph(unittest.TestCase):
fn(xv, yv, zv)) fn(xv, yv, zv))
def test_connection_pattern(self): def test_connection_pattern(self):
# Basic case
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
out1 = x * y out1 = x * y
out2 = y * z out2 = y * z
op = OpFromGraph([x ,y, z], [out1, out2], mode='FAST_RUN') op1 = OpFromGraph([x ,y, z], [out1, out2], mode='FAST_RUN')
results = op.connection_pattern(None) results = op1.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
[True, True], [True, True],
[False, True]] [False, True]]
assert results == expect_result assert results == expect_result
# Graph with ops that don't have a 'full' connection pattern
# and with ops that have multiple outputs
m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2], mode='FAST_RUN')
results = op2.connection_pattern(None)
expect_result = [[True, False],
[True, True],
[False, True],
[True, True]]
assert results == expect_result
# Inner graph where some computation doesn't rely on explicit inputs
srng = RandomStreams(seed=234)
rv_u = srng.uniform((2,2))
x, y = T.matrices('xy')
out1 = x + rv_u
out2 = y + 3
out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3], mode='FAST_RUN')
results = op2.connection_pattern(None)
expect_result = [[True, False, False],
[False, True, False]]
assert results == expect_result
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论