提交 22c71c9e authored 作者: --global's avatar --global

Add assert to validate the number of scan nodes

上级 85c8918a
......@@ -4052,7 +4052,10 @@ class T_Scan(unittest.TestCase):
# This used to raise an exception
f = theano.function([W, v], out, mode=mode_with_opt)
f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2])
scan_node = self.scan_nodes_from_fct(f)[0]
scan_nodes = self.scan_nodes_from_fct(f)
assert len(scan_nodes) == 1
scan_node = scan_nodes[0]
# The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) ==
......@@ -4093,7 +4096,10 @@ class T_Scan(unittest.TestCase):
f(numpy.zeros((3, 3), theano.config.floatX),
[1, 2],
numpy.zeros((3, 3), theano.config.floatX))
scan_node = self.scan_nodes_from_fct(f)[0]
scan_nodes = self.scan_nodes_from_fct(f)
assert len(scan_nodes) == 1
scan_node = scan_nodes[0]
# The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) ==
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论