提交 3673950e authored 作者: ChienliMa's avatar ChienliMa

infer_shape reuse scan.utils.infer_shape

上级 4725426b
......@@ -11,8 +11,10 @@ from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools
class T_OpFromGraph(unittest.TestCase):
class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_straightforward(self):
x, y, z = T.matrices('xyz')
......@@ -155,9 +157,10 @@ class T_OpFromGraph(unittest.TestCase):
x = T.matrix('x')
y = x+x
op_graph = OpFromGraph([x], [y], mode='FAST_RUN')
shapes = op_graph.infer_shape(None, [(5, 5)])
assert shapes[0] == (5, 5)
self._compile_and_check([x],
[op_graph(x)],
[numpy.ones([3,4])],
OpFromGraph)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论