提交 d749402a authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test in fast_compile mode.

上级 921a9604
......@@ -1391,6 +1391,7 @@ class T_subtensor(unittest.TestCase):
self.mode = mode
self.dtype = dtype
self.ignore_topo = ignore_topo
self.fast_compile = theano.config.mode == 'FAST_COMPILE'
return super(T_subtensor, self).__init__(name)
def setUp(self):
......@@ -1620,7 +1621,8 @@ class T_subtensor(unittest.TestCase):
f = inplace_func([], gn, mode=self.mode)
topo = f.maker.env.toposort()
topo_ = [node for node in topo if not isinstance(node.op, self.ignore_topo)]
assert len(topo_)==6
if not self.fast_compile:
assert len(topo_)==6
assert numpy.sum([isinstance(node.op, self.inc_sub) for node in topo_])==1
assert numpy.sum([isinstance(node.op, self.sub) for node in topo_])==1
gval = f()
......@@ -1637,7 +1639,8 @@ class T_subtensor(unittest.TestCase):
f = function([], gn, mode=self.mode)
topo = f.maker.env.toposort()
topo_ = [node for node in topo if not isinstance(node.op, self.ignore_topo)]
assert len(topo_)==6
if not self.fast_compile:
assert len(topo_)==6
assert numpy.sum([isinstance(node.op, self.inc_sub) for node in topo_])==1
assert numpy.sum([isinstance(node.op, self.sub) for node in topo_])==1
......@@ -1688,7 +1691,6 @@ class T_subtensor(unittest.TestCase):
def grad_list_(self, idxs, data):
n = self.shared(data)
fast_compile = theano.config.mode == 'FAST_COMPILE'
for idx in idxs:
# Should stay on the cpu.
......@@ -1697,7 +1699,7 @@ class T_subtensor(unittest.TestCase):
gn = grad(sum(exp(t)), n)
f = function([], [gn, gn.shape], mode=self.mode)
topo = f.maker.env.toposort()
if not fast_compile:
if not self.fast_compile:
assert any([isinstance(node.op, self.adv_incsub1) and node.op.inplace for node in topo])
else:
assert any([isinstance(node.op, self.adv_incsub1) for node in topo])
......@@ -1724,7 +1726,7 @@ class T_subtensor(unittest.TestCase):
if idx is idxs[0]:
f = function([], [gn.shape, n[idx_].shape], mode=self.mode)
topo = f.maker.env.toposort()
if not fast_compile:
if not self.fast_compile:
self.failUnless(not any([isinstance(node.op, self.adv_incsub1) for node in topo]))
self.failUnless(not any([isinstance(node.op, self.adv_sub1) for node in topo]))
f()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论