提交 9c386cef authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the arange tests to expect the right dtype.

上级 f1727ebe
...@@ -5316,7 +5316,7 @@ class TestARange(unittest.TestCase): ...@@ -5316,7 +5316,7 @@ class TestARange(unittest.TestCase):
f = function([start, stop, step], out) f = function([start, stop, step], out)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(numpy.array(1, dtype='int32')).dtype numpy_dtype = numpy.arange(numpy.array(1, dtype='int32')).dtype
assert out.dtype == numpy_dtype assert out.dtype == numpy_dtype
...@@ -5393,7 +5393,7 @@ class TestARange(unittest.TestCase): ...@@ -5393,7 +5393,7 @@ class TestARange(unittest.TestCase):
f = function([start, stop], out) f = function([start, stop], out)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(numpy.int32(0), assert out.dtype == numpy.arange(numpy.int32(0),
numpy.int32(1)).dtype numpy.int32(1)).dtype
...@@ -5421,7 +5421,7 @@ class TestARange(unittest.TestCase): ...@@ -5421,7 +5421,7 @@ class TestARange(unittest.TestCase):
f = function([stop], out) f = function([stop], out)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == stop.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange(numpy.int32(1)).dtype assert out.dtype == numpy.arange(numpy.int32(1)).dtype
else: else:
...@@ -5453,7 +5453,7 @@ class TestARange(unittest.TestCase): ...@@ -5453,7 +5453,7 @@ class TestARange(unittest.TestCase):
def test_upcast(self): def test_upcast(self):
"""Test that arange computes output type adequately""" """Test that arange computes output type adequately"""
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert arange(iscalar()).dtype == iscalar().dtype assert arange(iscalar()).dtype == 'int64'
assert arange(fscalar()).dtype == fscalar().dtype assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype assert arange(dscalar()).dtype == dscalar().dtype
...@@ -5547,7 +5547,7 @@ class TestARange(unittest.TestCase): ...@@ -5547,7 +5547,7 @@ class TestARange(unittest.TestCase):
assert len(f.maker.fgraph.toposort()) == 9 assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(numpy.array(0, dtype=start.dtype), numpy_dtype = numpy.arange(numpy.array(0, dtype=start.dtype),
numpy.array(1, dtype=stop.dtype), numpy.array(1, dtype=stop.dtype),
...@@ -5568,7 +5568,7 @@ class TestARange(unittest.TestCase): ...@@ -5568,7 +5568,7 @@ class TestARange(unittest.TestCase):
assert len(f.maker.fgraph.toposort()) == 5 assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] # 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
assert out.dtype == numpy.arange( assert out.dtype == numpy.arange(
numpy.int32(0), numpy.int32(1), numpy.int32(1)).dtype numpy.int32(0), numpy.int32(1), numpy.int32(1)).dtype
...@@ -5590,7 +5590,7 @@ class TestARange(unittest.TestCase): ...@@ -5590,7 +5590,7 @@ class TestARange(unittest.TestCase):
#[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)] #[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
assert out.dtype == start.type.dtype assert out.dtype == 'int64'
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
numpy_dtype = numpy.arange(0, numpy_dtype = numpy.arange(0,
numpy.array(1, dtype=stop.dtype), numpy.array(1, dtype=stop.dtype),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论