提交 85c8918a authored 作者: --global's avatar --global

Fix how some tests get a reference to scan nodes in a theano function

上级 93cc458d
...@@ -208,6 +208,11 @@ class T_Scan(unittest.TestCase): ...@@ -208,6 +208,11 @@ class T_Scan(unittest.TestCase):
utt.seed_rng() utt.seed_rng()
super(T_Scan, self).setUp() super(T_Scan, self).setUp()
def scan_nodes_from_fct(self, fct):
nodes = fct.maker.fgraph.toposort()
scan_nodes = [n for n in nodes if isinstance(n.op, Scan)]
return scan_nodes
# generator network, only one output , type scalar ; no sequence or # generator network, only one output , type scalar ; no sequence or
# non sequence arguments # non sequence arguments
@dec.knownfailureif( @dec.knownfailureif(
...@@ -4047,7 +4052,7 @@ class T_Scan(unittest.TestCase): ...@@ -4047,7 +4052,7 @@ class T_Scan(unittest.TestCase):
# This used to raise an exception # This used to raise an exception
f = theano.function([W, v], out, mode=mode_with_opt) f = theano.function([W, v], out, mode=mode_with_opt)
f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2]) f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2])
scan_node = f.maker.fgraph.toposort()[-1] scan_node = self.scan_nodes_from_fct(f)[0]
# The first input is the number of iteration. # The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) == assert (len(scan_node.inputs[1:]) ==
...@@ -4088,7 +4093,7 @@ class T_Scan(unittest.TestCase): ...@@ -4088,7 +4093,7 @@ class T_Scan(unittest.TestCase):
f(numpy.zeros((3, 3), theano.config.floatX), f(numpy.zeros((3, 3), theano.config.floatX),
[1, 2], [1, 2],
numpy.zeros((3, 3), theano.config.floatX)) numpy.zeros((3, 3), theano.config.floatX))
scan_node = f.maker.fgraph.toposort()[-1] scan_node = self.scan_nodes_from_fct(f)[0]
# The first input is the number of iteration. # The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) == assert (len(scan_node.inputs[1:]) ==
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论