提交 cf9d71bb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize tests on cast_policy value

上级 196512e9
...@@ -1848,7 +1848,15 @@ def test_TensorFromScalar(): ...@@ -1848,7 +1848,15 @@ def test_TensorFromScalar():
tensor_from_scalar(vector()) tensor_from_scalar(vector())
def test_ScalarFromTensor(): @pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_ScalarFromTensor(cast_policy):
with config.change_flags(cast_policy=cast_policy):
tc = constant(56) # aes.constant(56) tc = constant(56) # aes.constant(56)
ss = scalar_from_tensor(tc) ss = scalar_from_tensor(tc)
assert ss.owner.op is scalar_from_tensor assert ss.owner.op is scalar_from_tensor
...@@ -1859,12 +1867,10 @@ def test_ScalarFromTensor(): ...@@ -1859,12 +1867,10 @@ def test_ScalarFromTensor():
assert v == 56 assert v == 56
assert v.shape == () assert v.shape == ()
if config.cast_policy == "custom": if cast_policy == "custom":
assert isinstance(v, np.int8) assert isinstance(v, np.int8)
elif config.cast_policy in ("numpy", "numpy+floatX"): elif cast_policy == "numpy+floatX":
assert isinstance(v, str(np.asarray(56).dtype)) assert isinstance(v, np.int64)
else:
raise NotImplementedError(config.cast_policy)
aes = lscalar() aes = lscalar()
ss = scalar_from_tensor(aes) ss = scalar_from_tensor(aes)
...@@ -2231,19 +2237,26 @@ class TestARange: ...@@ -2231,19 +2237,26 @@ class TestARange:
rng=rng, rng=rng,
) )
def test_integers(self): @pytest.mark.parametrize(
# Test arange constructor, on integer outputs "cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_integers(self, cast_policy):
"""Test arange constructor, on integer outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = iscalars("start", "stop", "step") start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step) out = arange(start, stop, step)
f = function([start, stop, step], out) f = function([start, stop, step], out)
if config.cast_policy == "custom": if cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): elif cast_policy == "numpy+floatX":
numpy_dtype = np.arange(np.array(1, dtype="int32")).dtype numpy_dtype = np.arange(np.array(1, dtype="int32")).dtype
assert out.dtype == numpy_dtype assert out.dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5, 1) == np.arange(0, 5, 1)) assert np.all(f(0, 5, 1) == np.arange(0, 5, 1))
assert np.all(f(2, 11, 4) == np.arange(2, 11, 4)) assert np.all(f(2, 11, 4) == np.arange(2, 11, 4))
assert np.all(f(-5, 1, 1) == np.arange(-5, 1, 1)) assert np.all(f(-5, 1, 1) == np.arange(-5, 1, 1))
...@@ -2251,74 +2264,104 @@ class TestARange: ...@@ -2251,74 +2264,104 @@ class TestARange:
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2)) assert np.all(f(10, 2, 2) == np.arange(10, 2, 2))
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1)) assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
def test_float32(self): @pytest.mark.parametrize(
# Test arange constructor, on float32 outputs "cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_float32(self, cast_policy):
"""Test arange constructor, on float32 outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = fscalars("start", "stop", "step") start, stop, step = fscalars("start", "stop", "step")
out = arange(start, stop, step) out = arange(start, stop, step)
f = function([start, stop, step], out) f = function([start, stop, step], out)
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
elif config.cast_policy == "numpy":
numpy_dtype = np.arange(
np.array(0, dtype=start.dtype),
np.array(1, dtype=stop.dtype),
np.array(1, dtype=step.dtype),
).dtype
assert out.dtype == numpy_dtype
elif config.cast_policy == "numpy+floatX": elif config.cast_policy == "numpy+floatX":
assert out.dtype == config.floatX assert out.dtype == config.floatX
else:
raise NotImplementedError(config.cast_policy) arg_vals = [
arg_vals = [(0, 5, 1), (2, 11, 4), (-5, 1.1, 1.2), (1.3, 2, -2.1), (10, 2, 2)] (0, 5, 1),
(2, 11, 4),
(-5, 1.1, 1.2),
(1.3, 2, -2.1),
(10, 2, 2),
]
for arg_v in arg_vals: for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype) start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_) f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom": if config.cast_policy == "custom":
expected_val = np.arange( expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype start_v, stop_v, step_v, dtype=start.type.dtype
) )
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
expected_val = np.arange(start_v_, stop_v_, step_v_, dtype=out.dtype) expected_val = np.arange(
else: start_v_, stop_v_, step_v_, dtype=out.dtype
raise NotImplementedError(config.cast_policy) )
assert np.all(f_val == expected_val) assert np.all(f_val == expected_val)
def test_float64(self): @pytest.mark.parametrize(
# Test arange constructor, on float64 outputs "cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_float64(self, cast_policy):
"""Test arange constructor, on float64 outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = dscalars("start", "stop", "step") start, stop, step = dscalars("start", "stop", "step")
out = arange(start, stop, step) out = arange(start, stop, step)
f = function([start, stop, step], out) f = function([start, stop, step], out)
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
arg_vals = [(0, 5, 1), (2, 11, 4), (-5, 1.1, 1.2), (1.3, 2, -2.1), (10, 2, 2)]
arg_vals = [
(0, 5, 1),
(2, 11, 4),
(-5, 1.1, 1.2),
(1.3, 2, -2.1),
(10, 2, 2),
]
for arg_v in arg_vals: for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype) start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_) f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom": if config.cast_policy == "custom":
expected_val = np.arange( expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype start_v, stop_v, step_v, dtype=start.type.dtype
) )
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
expected_val = np.arange(start_v_, stop_v_, step_v_) expected_val = np.arange(start_v_, stop_v_, step_v_)
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f_val == expected_val) assert np.all(f_val == expected_val)
def test_default_step(self): @pytest.mark.parametrize(
# Test that arange constructor uses the correct default step "cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_default_step(self, cast_policy):
"""Test that arange constructor uses the correct default step."""
with config.change_flags(cast_policy=cast_policy):
start, stop = iscalars("start", "stop") start, stop = iscalars("start", "stop")
out = arange(start, stop) out = arange(start, stop)
f = function([start, stop], out) f = function([start, stop], out)
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
assert out.dtype == np.arange(np.int32(0), np.int32(1)).dtype assert out.dtype == np.arange(np.int32(0), np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5) == np.arange(0, 5)) assert np.all(f(0, 5) == np.arange(0, 5))
assert np.all(f(-5, 1) == np.arange(-5, 1)) assert np.all(f(-5, 1) == np.arange(-5, 1))
assert np.all(f(0, 0) == np.arange(0, 0)) assert np.all(f(0, 0) == np.arange(0, 0))
...@@ -2334,18 +2377,25 @@ class TestARange: ...@@ -2334,18 +2377,25 @@ class TestARange:
assert np.all(df(0.8, 5.3) == np.arange(0.8, 5.3)) assert np.all(df(0.8, 5.3) == np.arange(0.8, 5.3))
assert np.all(df(-0.7, 5.3) == np.arange(-0.7, 5.3)) assert np.all(df(-0.7, 5.3) == np.arange(-0.7, 5.3))
def test_default_start(self): @pytest.mark.parametrize(
# Test that arange constructor uses the correct default start "cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_default_start(self, cast_policy):
"""Test that arange constructor uses the correct default start."""
with config.change_flags(cast_policy=cast_policy):
stop = iscalar("stop") stop = iscalar("stop")
out = arange(stop) out = arange(stop)
f = function([stop], out) f = function([stop], out)
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
assert out.dtype == np.arange(np.int32(1)).dtype assert out.dtype == np.arange(np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(8) == np.arange(8)) assert np.all(f(8) == np.arange(8))
assert np.all(f(-2) == np.arange(-2)) assert np.all(f(-2) == np.arange(-2))
...@@ -2355,23 +2405,27 @@ class TestARange: ...@@ -2355,23 +2405,27 @@ class TestARange:
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert fout.dtype == fstop.type.dtype assert fout.dtype == fstop.type.dtype
elif config.cast_policy == "numpy":
assert fout.dtype == np.arange(np.float32(1)).dtype
elif config.cast_policy == "numpy+floatX": elif config.cast_policy == "numpy+floatX":
if config.floatX == "float32": if config.floatX == "float32":
assert fout.dtype == "float32" assert fout.dtype == "float32"
else: else:
assert fout.dtype == np.arange(np.float32(1)).dtype assert fout.dtype == np.arange(np.float32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
fstop_values = [0.2, -0.7, 8.5] fstop_values = [0.2, -0.7, 8.5]
for fstop_v in fstop_values: for fstop_v in fstop_values:
fstop_v32 = np.float32(fstop_v) fstop_v32 = np.float32(fstop_v)
assert np.all(ff(fstop_v32) == np.arange(fstop_v)) assert np.all(ff(fstop_v32) == np.arange(fstop_v))
def test_upcast(self): @pytest.mark.parametrize(
# Test that arange computes output type adequately "cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_upcast(self, cast_policy):
"""Test that arange computes output type adequately."""
with config.change_flags(cast_policy=cast_policy):
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert arange(iscalar()).dtype == "int64" assert arange(iscalar()).dtype == "int64"
assert arange(fscalar()).dtype == fscalar().dtype assert arange(fscalar()).dtype == fscalar().dtype
...@@ -2383,7 +2437,7 @@ class TestARange: ...@@ -2383,7 +2437,7 @@ class TestARange:
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
for dtype in get_numeric_types(): for dtype in get_numeric_types():
# Test with a single argument. # Test with a single argument.
arange_dtype = arange(scalar(dtype=str(dtype))).dtype arange_dtype = arange(scalar(dtype=str(dtype))).dtype
...@@ -2448,8 +2502,6 @@ class TestARange: ...@@ -2448,8 +2502,6 @@ class TestARange:
else: else:
# Follow numpy. # Follow numpy.
assert arange_dtype == numpy_dtype assert arange_dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
def test_dtype_cache(self): def test_dtype_cache(self):
# Checks that the same Op is returned on repeated calls to arange # Checks that the same Op is returned on repeated calls to arange
...@@ -2465,7 +2517,15 @@ class TestARange: ...@@ -2465,7 +2517,15 @@ class TestARange:
assert out2.owner.op is out3.owner.op assert out2.owner.op is out3.owner.op
assert out3.owner.op is not out4.owner.op assert out3.owner.op is not out4.owner.op
def test_infer_shape(self): @pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_infer_shape(self, cast_policy):
with config.change_flags(cast_policy=cast_policy):
start, stop, step = iscalars("start", "stop", "step") start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step) out = arange(start, stop, step)
mode = config.mode mode = config.mode
...@@ -2477,15 +2537,13 @@ class TestARange: ...@@ -2477,15 +2537,13 @@ class TestARange:
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
numpy_dtype = np.arange( numpy_dtype = np.arange(
np.array(0, dtype=start.dtype), np.array(0, dtype=start.dtype),
np.array(1, dtype=stop.dtype), np.array(1, dtype=stop.dtype),
np.array(1, dtype=step.dtype), np.array(1, dtype=step.dtype),
).dtype ).dtype
assert out.dtype == numpy_dtype assert out.dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5, 1) == len(np.arange(0, 5, 1))) assert np.all(f(0, 5, 1) == len(np.arange(0, 5, 1)))
assert np.all(f(2, 11, 4) == len(np.arange(2, 11, 4))) assert np.all(f(2, 11, 4) == len(np.arange(2, 11, 4)))
...@@ -2500,10 +2558,11 @@ class TestARange: ...@@ -2500,10 +2558,11 @@ class TestARange:
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] # 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
assert out.dtype == np.arange(np.int32(0), np.int32(1), np.int32(1)).dtype assert (
else: out.dtype == np.arange(np.int32(0), np.int32(1), np.int32(1)).dtype
raise NotImplementedError(config.cast_policy) )
assert np.all(f(0, 5) == len(np.arange(0, 5))) assert np.all(f(0, 5) == len(np.arange(0, 5)))
assert np.all(f(2, 11) == len(np.arange(2, 11))) assert np.all(f(2, 11) == len(np.arange(2, 11)))
assert np.all(f(-5, 1) == len(np.arange(-5, 1))) assert np.all(f(-5, 1) == len(np.arange(-5, 1)))
...@@ -2521,11 +2580,9 @@ class TestARange: ...@@ -2521,11 +2580,9 @@ class TestARange:
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): elif config.cast_policy == "numpy+floatX":
numpy_dtype = np.arange(0, np.array(1, dtype=stop.dtype), 1).dtype numpy_dtype = np.arange(0, np.array(1, dtype=stop.dtype), 1).dtype
assert out.dtype == numpy_dtype assert out.dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(5) == len(np.arange(0, 5))) assert np.all(f(5) == len(np.arange(0, 5)))
assert np.all(f(11) == len(np.arange(0, 11))) assert np.all(f(11) == len(np.arange(0, 11)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论