提交 c3f94f29 authored 作者: --global's avatar --global

Factor scan_nodes_from_fct() out of T_Scan

上级 f705a5d3
......@@ -203,17 +203,18 @@ def grab_scan_node(output):
return rval
def scan_nodes_from_fct(fct):
nodes = fct.maker.fgraph.toposort()
scan_nodes = [n for n in nodes if isinstance(n.op, Scan)]
return scan_nodes
class T_Scan(unittest.TestCase):
def setUp(self):
utt.seed_rng()
super(T_Scan, self).setUp()
def scan_nodes_from_fct(self, fct):
nodes = fct.maker.fgraph.toposort()
scan_nodes = [n for n in nodes if isinstance(n.op, Scan)]
return scan_nodes
# generator network, only one output , type scalar ; no sequence or
# non sequence arguments
@dec.knownfailureif(
......@@ -4261,7 +4262,7 @@ class T_Scan(unittest.TestCase):
f = theano.function([W, v], out, mode=mode_with_opt)
f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2])
scan_nodes = self.scan_nodes_from_fct(f)
scan_nodes = scan_nodes_from_fct(f)
assert len(scan_nodes) == 1
scan_node = scan_nodes[0]
......@@ -4305,7 +4306,7 @@ class T_Scan(unittest.TestCase):
[1, 2],
numpy.zeros((3, 3), theano.config.floatX))
scan_nodes = self.scan_nodes_from_fct(f)
scan_nodes = scan_nodes_from_fct(f)
assert len(scan_nodes) == 1
scan_node = scan_nodes[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论