提交 cf9d71bb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize tests on cast_policy value

上级 196512e9
...@@ -1848,35 +1848,41 @@ def test_TensorFromScalar(): ...@@ -1848,35 +1848,41 @@ def test_TensorFromScalar():
tensor_from_scalar(vector()) tensor_from_scalar(vector())
def test_ScalarFromTensor(): @pytest.mark.parametrize(
tc = constant(56) # aes.constant(56) "cast_policy",
ss = scalar_from_tensor(tc) [
assert ss.owner.op is scalar_from_tensor "custom",
assert ss.type.dtype == tc.type.dtype "numpy+floatX",
],
v = eval_outputs([ss]) )
def test_ScalarFromTensor(cast_policy):
assert v == 56 with config.change_flags(cast_policy=cast_policy):
assert v.shape == () tc = constant(56) # aes.constant(56)
ss = scalar_from_tensor(tc)
if config.cast_policy == "custom": assert ss.owner.op is scalar_from_tensor
assert isinstance(v, np.int8) assert ss.type.dtype == tc.type.dtype
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert isinstance(v, str(np.asarray(56).dtype)) v = eval_outputs([ss])
else:
raise NotImplementedError(config.cast_policy) assert v == 56
assert v.shape == ()
aes = lscalar()
ss = scalar_from_tensor(aes) if cast_policy == "custom":
ss.owner.op.grad([aes], [ss]) assert isinstance(v, np.int8)
fff = function([aes], ss) elif cast_policy == "numpy+floatX":
v = fff(np.asarray(5)) assert isinstance(v, np.int64)
assert v == 5
assert isinstance(v, np.int64) aes = lscalar()
assert v.shape == () ss = scalar_from_tensor(aes)
ss.owner.op.grad([aes], [ss])
fff = function([aes], ss)
v = fff(np.asarray(5))
assert v == 5
assert isinstance(v, np.int64)
assert v.shape == ()
with pytest.raises(TypeError): with pytest.raises(TypeError):
scalar_from_tensor(vector()) scalar_from_tensor(vector())
class TestOpCache: class TestOpCache:
...@@ -2231,188 +2237,213 @@ class TestARange: ...@@ -2231,188 +2237,213 @@ class TestARange:
rng=rng, rng=rng,
) )
def test_integers(self): @pytest.mark.parametrize(
# Test arange constructor, on integer outputs "cast_policy",
start, stop, step = iscalars("start", "stop", "step") [
out = arange(start, stop, step) "custom",
f = function([start, stop, step], out) "numpy+floatX",
],
)
def test_integers(self, cast_policy):
"""Test arange constructor, on integer outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
if cast_policy == "custom":
assert out.dtype == "int64"
elif cast_policy == "numpy+floatX":
numpy_dtype = np.arange(np.array(1, dtype="int32")).dtype
assert out.dtype == numpy_dtype
assert np.all(f(0, 5, 1) == np.arange(0, 5, 1))
assert np.all(f(2, 11, 4) == np.arange(2, 11, 4))
assert np.all(f(-5, 1, 1) == np.arange(-5, 1, 1))
assert np.all(f(10, 2, -2) == np.arange(10, 2, -2))
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2))
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
if config.cast_policy == "custom": @pytest.mark.parametrize(
assert out.dtype == "int64" "cast_policy",
elif config.cast_policy in ("numpy", "numpy+floatX"): [
numpy_dtype = np.arange(np.array(1, dtype="int32")).dtype "custom",
assert out.dtype == numpy_dtype "numpy+floatX",
else: ],
raise NotImplementedError(config.cast_policy) )
assert np.all(f(0, 5, 1) == np.arange(0, 5, 1)) def test_float32(self, cast_policy):
assert np.all(f(2, 11, 4) == np.arange(2, 11, 4)) """Test arange constructor, on float32 outputs."""
assert np.all(f(-5, 1, 1) == np.arange(-5, 1, 1)) with config.change_flags(cast_policy=cast_policy):
assert np.all(f(10, 2, -2) == np.arange(10, 2, -2)) start, stop, step = fscalars("start", "stop", "step")
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2)) out = arange(start, stop, step)
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1)) f = function([start, stop, step], out)
def test_float32(self): if config.cast_policy == "custom":
# Test arange constructor, on float32 outputs assert out.dtype == start.type.dtype
start, stop, step = fscalars("start", "stop", "step") elif config.cast_policy == "numpy+floatX":
out = arange(start, stop, step) assert out.dtype == config.floatX
f = function([start, stop, step], out)
arg_vals = [
(0, 5, 1),
(2, 11, 4),
(-5, 1.1, 1.2),
(1.3, 2, -2.1),
(10, 2, 2),
]
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom":
expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype
)
elif config.cast_policy == "numpy+floatX":
expected_val = np.arange(
start_v_, stop_v_, step_v_, dtype=out.dtype
)
assert np.all(f_val == expected_val)
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_float64(self, cast_policy):
"""Test arange constructor, on float64 outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = dscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
if config.cast_policy == "custom":
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
elif config.cast_policy == "numpy":
numpy_dtype = np.arange( arg_vals = [
np.array(0, dtype=start.dtype), (0, 5, 1),
np.array(1, dtype=stop.dtype), (2, 11, 4),
np.array(1, dtype=step.dtype), (-5, 1.1, 1.2),
).dtype (1.3, 2, -2.1),
assert out.dtype == numpy_dtype (10, 2, 2),
elif config.cast_policy == "numpy+floatX": ]
assert out.dtype == config.floatX for arg_v in arg_vals:
else: start_v, stop_v, step_v = arg_v
raise NotImplementedError(config.cast_policy) start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
arg_vals = [(0, 5, 1), (2, 11, 4), (-5, 1.1, 1.2), (1.3, 2, -2.1), (10, 2, 2)] f_val = f(start_v_, stop_v_, step_v_)
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v if config.cast_policy == "custom":
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype) expected_val = np.arange(
f_val = f(start_v_, stop_v_, step_v_) start_v, stop_v, step_v, dtype=start.type.dtype
)
elif config.cast_policy == "numpy+floatX":
expected_val = np.arange(start_v_, stop_v_, step_v_)
assert np.all(f_val == expected_val)
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_default_step(self, cast_policy):
"""Test that arange constructor uses the correct default step."""
with config.change_flags(cast_policy=cast_policy):
start, stop = iscalars("start", "stop")
out = arange(start, stop)
f = function([start, stop], out)
if config.cast_policy == "custom": if config.cast_policy == "custom":
expected_val = np.arange( assert out.dtype == "int64"
start_v, stop_v, step_v, dtype=start.type.dtype elif config.cast_policy == "numpy+floatX":
) assert out.dtype == np.arange(np.int32(0), np.int32(1)).dtype
elif config.cast_policy in ("numpy", "numpy+floatX"):
expected_val = np.arange(start_v_, stop_v_, step_v_, dtype=out.dtype)
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f_val == expected_val)
def test_float64(self): assert np.all(f(0, 5) == np.arange(0, 5))
# Test arange constructor, on float64 outputs assert np.all(f(-5, 1) == np.arange(-5, 1))
start, stop, step = dscalars("start", "stop", "step") assert np.all(f(0, 0) == np.arange(0, 0))
out = arange(start, stop, step)
f = function([start, stop, step], out) dstart, dstop = dscalars("start", "stop")
dout = arange(dstart, dstop)
df = function([dstart, dstop], dout)
assert dout.dtype == dstart.type.dtype
# print df(0.2, 5.3)
# print np.arange(0.2, 5.3)
assert np.all(df(0.2, 5.3) == np.arange(0.2, 5.3))
assert np.all(df(0.8, 5.3) == np.arange(0.8, 5.3))
assert np.all(df(-0.7, 5.3) == np.arange(-0.7, 5.3))
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_default_start(self, cast_policy):
"""Test that arange constructor uses the correct default start."""
with config.change_flags(cast_policy=cast_policy):
stop = iscalar("stop")
out = arange(stop)
f = function([stop], out)
assert out.dtype == start.type.dtype
arg_vals = [(0, 5, 1), (2, 11, 4), (-5, 1.1, 1.2), (1.3, 2, -2.1), (10, 2, 2)]
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom": if config.cast_policy == "custom":
expected_val = np.arange( assert out.dtype == "int64"
start_v, stop_v, step_v, dtype=start.type.dtype elif config.cast_policy == "numpy+floatX":
) assert out.dtype == np.arange(np.int32(1)).dtype
elif config.cast_policy in ("numpy", "numpy+floatX"):
expected_val = np.arange(start_v_, stop_v_, step_v_) assert np.all(f(8) == np.arange(8))
else: assert np.all(f(-2) == np.arange(-2))
raise NotImplementedError(config.cast_policy)
assert np.all(f_val == expected_val) fstop = fscalar("stop")
fout = arange(fstop)
def test_default_step(self): ff = function([fstop], fout)
# Test that arange constructor uses the correct default step
start, stop = iscalars("start", "stop") if config.cast_policy == "custom":
out = arange(start, stop) assert fout.dtype == fstop.type.dtype
f = function([start, stop], out) elif config.cast_policy == "numpy+floatX":
if config.floatX == "float32":
if config.cast_policy == "custom": assert fout.dtype == "float32"
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert out.dtype == np.arange(np.int32(0), np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5) == np.arange(0, 5))
assert np.all(f(-5, 1) == np.arange(-5, 1))
assert np.all(f(0, 0) == np.arange(0, 0))
dstart, dstop = dscalars("start", "stop")
dout = arange(dstart, dstop)
df = function([dstart, dstop], dout)
assert dout.dtype == dstart.type.dtype
# print df(0.2, 5.3)
# print np.arange(0.2, 5.3)
assert np.all(df(0.2, 5.3) == np.arange(0.2, 5.3))
assert np.all(df(0.8, 5.3) == np.arange(0.8, 5.3))
assert np.all(df(-0.7, 5.3) == np.arange(-0.7, 5.3))
def test_default_start(self):
# Test that arange constructor uses the correct default start
stop = iscalar("stop")
out = arange(stop)
f = function([stop], out)
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert out.dtype == np.arange(np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(8) == np.arange(8))
assert np.all(f(-2) == np.arange(-2))
fstop = fscalar("stop")
fout = arange(fstop)
ff = function([fstop], fout)
if config.cast_policy == "custom":
assert fout.dtype == fstop.type.dtype
elif config.cast_policy == "numpy":
assert fout.dtype == np.arange(np.float32(1)).dtype
elif config.cast_policy == "numpy+floatX":
if config.floatX == "float32":
assert fout.dtype == "float32"
else:
assert fout.dtype == np.arange(np.float32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
fstop_values = [0.2, -0.7, 8.5]
for fstop_v in fstop_values:
fstop_v32 = np.float32(fstop_v)
assert np.all(ff(fstop_v32) == np.arange(fstop_v))
def test_upcast(self):
# Test that arange computes output type adequately
if config.cast_policy == "custom":
assert arange(iscalar()).dtype == "int64"
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
elif config.cast_policy in ("numpy", "numpy+floatX"):
for dtype in get_numeric_types():
# Test with a single argument.
arange_dtype = arange(scalar(dtype=str(dtype))).dtype
numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype
if (
dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
):
# We want a float32 arange.
assert arange_dtype == "float32"
else: else:
# Follow numpy. assert fout.dtype == np.arange(np.float32(1)).dtype
assert arange_dtype == numpy_dtype
fstop_values = [0.2, -0.7, 8.5]
# Test with two arguments. for fstop_v in fstop_values:
for stop_dtype in get_numeric_types(): fstop_v32 = np.float32(fstop_v)
arange_dtype = arange( assert np.all(ff(fstop_v32) == np.arange(fstop_v))
start=scalar(dtype=str(dtype)),
stop=scalar(dtype=str(stop_dtype)), @pytest.mark.parametrize(
).dtype "cast_policy",
numpy_dtype = np.arange( [
start=np.array(0, dtype=dtype), "custom",
stop=np.array(1, dtype=stop_dtype), "numpy+floatX",
).dtype ],
)
def test_upcast(self, cast_policy):
"""Test that arange computes output type adequately."""
with config.change_flags(cast_policy=cast_policy):
if config.cast_policy == "custom":
assert arange(iscalar()).dtype == "int64"
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
elif config.cast_policy == "numpy+floatX":
for dtype in get_numeric_types():
# Test with a single argument.
arange_dtype = arange(scalar(dtype=str(dtype))).dtype
numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype
if ( if (
dtype != "float64" dtype != "float64"
and stop_dtype != "float64"
and numpy_dtype == "float64" and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX" and config.cast_policy == "numpy+floatX"
and config.floatX == "float32" and config.floatX == "float32"
...@@ -2423,22 +2454,19 @@ class TestARange: ...@@ -2423,22 +2454,19 @@ class TestARange:
# Follow numpy. # Follow numpy.
assert arange_dtype == numpy_dtype assert arange_dtype == numpy_dtype
# Test with three arguments. # Test with two arguments.
for step_dtype in get_numeric_types(): for stop_dtype in get_numeric_types():
arange_dtype = arange( arange_dtype = arange(
start=scalar(dtype=str(dtype)), start=scalar(dtype=str(dtype)),
stop=scalar(dtype=str(stop_dtype)), stop=scalar(dtype=str(stop_dtype)),
step=scalar(dtype=str(step_dtype)),
).dtype ).dtype
numpy_dtype = np.arange( numpy_dtype = np.arange(
start=np.array(0, dtype=dtype), start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype), stop=np.array(1, dtype=stop_dtype),
step=np.array(1, dtype=step_dtype),
).dtype ).dtype
if ( if (
dtype != "float64" dtype != "float64"
and stop_dtype != "float64" and stop_dtype != "float64"
and step_dtype != "float64"
and numpy_dtype == "float64" and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX" and config.cast_policy == "numpy+floatX"
and config.floatX == "float32" and config.floatX == "float32"
...@@ -2448,8 +2476,32 @@ class TestARange: ...@@ -2448,8 +2476,32 @@ class TestARange:
else: else:
# Follow numpy. # Follow numpy.
assert arange_dtype == numpy_dtype assert arange_dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy) # Test with three arguments.
for step_dtype in get_numeric_types():
arange_dtype = arange(
start=scalar(dtype=str(dtype)),
stop=scalar(dtype=str(stop_dtype)),
step=scalar(dtype=str(step_dtype)),
).dtype
numpy_dtype = np.arange(
start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype),
step=np.array(1, dtype=step_dtype),
).dtype
if (
dtype != "float64"
and stop_dtype != "float64"
and step_dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
):
# We want a float32 arange.
assert arange_dtype == "float32"
else:
# Follow numpy.
assert arange_dtype == numpy_dtype
def test_dtype_cache(self): def test_dtype_cache(self):
# Checks that the same Op is returned on repeated calls to arange # Checks that the same Op is returned on repeated calls to arange
...@@ -2465,74 +2517,79 @@ class TestARange: ...@@ -2465,74 +2517,79 @@ class TestARange:
assert out2.owner.op is out3.owner.op assert out2.owner.op is out3.owner.op
assert out3.owner.op is not out4.owner.op assert out3.owner.op is not out4.owner.op
def test_infer_shape(self): @pytest.mark.parametrize(
start, stop, step = iscalars("start", "stop", "step") "cast_policy",
out = arange(start, stop, step) [
mode = config.mode "custom",
if mode == "FAST_COMPILE": "numpy+floatX",
mode = "FAST_RUN" ],
mode = compile.mode.get_mode(mode).excluding("fusion") )
f = function([start, stop, step], out.shape, mode=mode) def test_infer_shape(self, cast_policy):
assert len(f.maker.fgraph.toposort()) == 9 with config.change_flags(cast_policy=cast_policy):
start, stop, step = iscalars("start", "stop", "step")
if config.cast_policy == "custom": out = arange(start, stop, step)
assert out.dtype == "int64" mode = config.mode
elif config.cast_policy in ("numpy", "numpy+floatX"): if mode == "FAST_COMPILE":
numpy_dtype = np.arange( mode = "FAST_RUN"
np.array(0, dtype=start.dtype), mode = compile.mode.get_mode(mode).excluding("fusion")
np.array(1, dtype=stop.dtype), f = function([start, stop, step], out.shape, mode=mode)
np.array(1, dtype=step.dtype), assert len(f.maker.fgraph.toposort()) == 9
).dtype
assert out.dtype == numpy_dtype if config.cast_policy == "custom":
else: assert out.dtype == "int64"
raise NotImplementedError(config.cast_policy) elif config.cast_policy == "numpy+floatX":
numpy_dtype = np.arange(
assert np.all(f(0, 5, 1) == len(np.arange(0, 5, 1))) np.array(0, dtype=start.dtype),
assert np.all(f(2, 11, 4) == len(np.arange(2, 11, 4))) np.array(1, dtype=stop.dtype),
assert np.all(f(-5, 1, 1) == len(np.arange(-5, 1, 1))) np.array(1, dtype=step.dtype),
assert np.all(f(10, 2, -2) == len(np.arange(10, 2, -2))) ).dtype
assert np.all(f(10, 2, 2) == len(np.arange(10, 2, 2))) assert out.dtype == numpy_dtype
assert np.all(f(0, 0, 1) == len(np.arange(0, 0, 1)))
assert np.all(f(0, 5, 1) == len(np.arange(0, 5, 1)))
out = arange(start, stop, 1) assert np.all(f(2, 11, 4) == len(np.arange(2, 11, 4)))
f = function([start, stop], out.shape, mode=mode) assert np.all(f(-5, 1, 1) == len(np.arange(-5, 1, 1)))
assert len(f.maker.fgraph.toposort()) == 5 assert np.all(f(10, 2, -2) == len(np.arange(10, 2, -2)))
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] assert np.all(f(10, 2, 2) == len(np.arange(10, 2, 2)))
if config.cast_policy == "custom": assert np.all(f(0, 0, 1) == len(np.arange(0, 0, 1)))
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): out = arange(start, stop, 1)
assert out.dtype == np.arange(np.int32(0), np.int32(1), np.int32(1)).dtype f = function([start, stop], out.shape, mode=mode)
else: assert len(f.maker.fgraph.toposort()) == 5
raise NotImplementedError(config.cast_policy) # 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert np.all(f(0, 5) == len(np.arange(0, 5))) if config.cast_policy == "custom":
assert np.all(f(2, 11) == len(np.arange(2, 11))) assert out.dtype == "int64"
assert np.all(f(-5, 1) == len(np.arange(-5, 1))) elif config.cast_policy == "numpy+floatX":
assert np.all(f(10, 2) == len(np.arange(10, 2))) assert (
assert np.all(f(10, 2) == len(np.arange(10, 2))) out.dtype == np.arange(np.int32(0), np.int32(1), np.int32(1)).dtype
assert np.all(f(0, 0) == len(np.arange(0, 0))) )
assert np.all(f(-64, 64) == len(np.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128] assert np.all(f(0, 5) == len(np.arange(0, 5)))
assert arange(-64, 64, 2).shape.eval() == [64] assert np.all(f(2, 11) == len(np.arange(2, 11)))
assert np.all(f(-5, 1) == len(np.arange(-5, 1)))
out = arange(0, stop, 1) assert np.all(f(10, 2) == len(np.arange(10, 2)))
f = function([stop], out.shape, mode=mode) assert np.all(f(10, 2) == len(np.arange(10, 2)))
assert len(f.maker.fgraph.toposort()) == 2 assert np.all(f(0, 0) == len(np.arange(0, 0)))
# [Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)] assert np.all(f(-64, 64) == len(np.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128]
if config.cast_policy == "custom": assert arange(-64, 64, 2).shape.eval() == [64]
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"): out = arange(0, stop, 1)
numpy_dtype = np.arange(0, np.array(1, dtype=stop.dtype), 1).dtype f = function([stop], out.shape, mode=mode)
assert out.dtype == numpy_dtype assert len(f.maker.fgraph.toposort()) == 2
else: # [Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
raise NotImplementedError(config.cast_policy)
if config.cast_policy == "custom":
assert np.all(f(5) == len(np.arange(0, 5))) assert out.dtype == "int64"
assert np.all(f(11) == len(np.arange(0, 11))) elif config.cast_policy == "numpy+floatX":
assert np.all(f(1) == len(np.arange(0, 1))) numpy_dtype = np.arange(0, np.array(1, dtype=stop.dtype), 1).dtype
assert np.all(f(2) == len(np.arange(0, 2))) assert out.dtype == numpy_dtype
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(0) == len(np.arange(0, 0))) assert np.all(f(5) == len(np.arange(0, 5)))
assert np.all(f(11) == len(np.arange(0, 11)))
assert np.all(f(1) == len(np.arange(0, 1)))
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(0) == len(np.arange(0, 0)))
class TestNdGrid: class TestNdGrid:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论