提交 056f4fa1 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test in fast_compile mode

上级 45696520
......@@ -2074,7 +2074,11 @@ class TestARange(unittest.TestCase):
def test_infer_shape(self):
start, stop, step = iscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).excluding('fusion')
f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.env.toposort())==7
#7 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{float64}}(Elemwise{sub,no_inplace}.0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, step), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
......@@ -2087,7 +2091,7 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(0,0,1) == len(numpy.arange(0,0,1)))
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.env.toposort())==4
#4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert out.dtype == start.type.dtype
......@@ -2099,7 +2103,7 @@ class TestARange(unittest.TestCase):
assert numpy.all(f(0,0) == len(numpy.arange(0,0)))
out = arange(0, stop, 1)
f = function([stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
f = function([stop], out.shape, mode=mode)
assert len(f.maker.env.toposort())==2
#[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论