提交 f8b334b8 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix test in FAST_COMPILE mode.

上级 5b788507
...@@ -242,7 +242,11 @@ class T_Scan(unittest.TestCase): ...@@ -242,7 +242,11 @@ class T_Scan(unittest.TestCase):
W_in = theano.tensor.fscalar('win') W_in = theano.tensor.fscalar('win')
W = theano.tensor.fscalar('w') W = theano.tensor.fscalar('w')
mode = theano.compile.mode.get_default_mode().including('gpu') if theano.config.mode == 'FAST_COMPILE':
mode = theano.compile.mode.get_mode('FAST_RUN')
else:
mode = theano.compile.mode.get_default_mode()
mode = mode.including('gpu','scan')
# The following line is needed to have the first case being used # The following line is needed to have the first case being used
# Otherwise, it is the second that is tested. # Otherwise, it is the second that is tested.
mode = mode.excluding('InputToGpuOptimizer') mode = mode.excluding('InputToGpuOptimizer')
...@@ -314,7 +318,11 @@ class T_Scan(unittest.TestCase): ...@@ -314,7 +318,11 @@ class T_Scan(unittest.TestCase):
x0 = theano.tensor.fscalar('x0') x0 = theano.tensor.fscalar('x0')
W_in = theano.tensor.fscalar('win') W_in = theano.tensor.fscalar('win')
W = theano.tensor.fscalar('w') W = theano.tensor.fscalar('w')
mode = theano.compile.mode.get_default_mode().including('gpu') if theano.config.mode == 'FAST_COMPILE':
mode = theano.compile.mode.get_mode('FAST_RUN')
else:
mode = theano.compile.mode.get_default_mode()
mode = mode.including('gpu','scan')
output, updates = theano.scan(f_rnn, u,x0,[W_in,W] output, updates = theano.scan(f_rnn, u,x0,[W_in,W]
, n_steps = None , n_steps = None
, truncate_gradient = -1 , truncate_gradient = -1
...@@ -1980,7 +1988,8 @@ class T_Scan(unittest.TestCase): ...@@ -1980,7 +1988,8 @@ class T_Scan(unittest.TestCase):
self.assertTrue(nb_scan == 2) self.assertTrue(nb_scan == 2)
nb_shape_i = len([n for n in topo nb_shape_i = len([n for n in topo
if isinstance(n.op, theano.tensor.opt.Shape_i)]) if isinstance(n.op, theano.tensor.opt.Shape_i)])
self.assertTrue(nb_shape_i == 1) if theano.config.mode != 'FAST_COMPILE':
self.assertTrue(nb_shape_i == 1)
def test_bug_josh_reported(self): def test_bug_josh_reported(self):
import theano import theano
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论