提交 cf9d71bb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize tests on cast_policy value

上级 196512e9
......@@ -1848,35 +1848,41 @@ def test_TensorFromScalar():
tensor_from_scalar(vector())
def test_ScalarFromTensor():
tc = constant(56) # aes.constant(56)
ss = scalar_from_tensor(tc)
assert ss.owner.op is scalar_from_tensor
assert ss.type.dtype == tc.type.dtype
v = eval_outputs([ss])
assert v == 56
assert v.shape == ()
if config.cast_policy == "custom":
assert isinstance(v, np.int8)
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert isinstance(v, str(np.asarray(56).dtype))
else:
raise NotImplementedError(config.cast_policy)
aes = lscalar()
ss = scalar_from_tensor(aes)
ss.owner.op.grad([aes], [ss])
fff = function([aes], ss)
v = fff(np.asarray(5))
assert v == 5
assert isinstance(v, np.int64)
assert v.shape == ()
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_ScalarFromTensor(cast_policy):
with config.change_flags(cast_policy=cast_policy):
tc = constant(56) # aes.constant(56)
ss = scalar_from_tensor(tc)
assert ss.owner.op is scalar_from_tensor
assert ss.type.dtype == tc.type.dtype
v = eval_outputs([ss])
assert v == 56
assert v.shape == ()
if cast_policy == "custom":
assert isinstance(v, np.int8)
elif cast_policy == "numpy+floatX":
assert isinstance(v, np.int64)
aes = lscalar()
ss = scalar_from_tensor(aes)
ss.owner.op.grad([aes], [ss])
fff = function([aes], ss)
v = fff(np.asarray(5))
assert v == 5
assert isinstance(v, np.int64)
assert v.shape == ()
with pytest.raises(TypeError):
scalar_from_tensor(vector())
with pytest.raises(TypeError):
scalar_from_tensor(vector())
class TestOpCache:
......@@ -2231,188 +2237,213 @@ class TestARange:
rng=rng,
)
def test_integers(self):
# Test arange constructor, on integer outputs
start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_integers(self, cast_policy):
"""Test arange constructor, on integer outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
if cast_policy == "custom":
assert out.dtype == "int64"
elif cast_policy == "numpy+floatX":
numpy_dtype = np.arange(np.array(1, dtype="int32")).dtype
assert out.dtype == numpy_dtype
assert np.all(f(0, 5, 1) == np.arange(0, 5, 1))
assert np.all(f(2, 11, 4) == np.arange(2, 11, 4))
assert np.all(f(-5, 1, 1) == np.arange(-5, 1, 1))
assert np.all(f(10, 2, -2) == np.arange(10, 2, -2))
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2))
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
numpy_dtype = np.arange(np.array(1, dtype="int32")).dtype
assert out.dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5, 1) == np.arange(0, 5, 1))
assert np.all(f(2, 11, 4) == np.arange(2, 11, 4))
assert np.all(f(-5, 1, 1) == np.arange(-5, 1, 1))
assert np.all(f(10, 2, -2) == np.arange(10, 2, -2))
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2))
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_float32(self, cast_policy):
"""Test arange constructor, on float32 outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = fscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
def test_float32(self):
# Test arange constructor, on float32 outputs
start, stop, step = fscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
if config.cast_policy == "custom":
assert out.dtype == start.type.dtype
elif config.cast_policy == "numpy+floatX":
assert out.dtype == config.floatX
arg_vals = [
(0, 5, 1),
(2, 11, 4),
(-5, 1.1, 1.2),
(1.3, 2, -2.1),
(10, 2, 2),
]
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom":
expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype
)
elif config.cast_policy == "numpy+floatX":
expected_val = np.arange(
start_v_, stop_v_, step_v_, dtype=out.dtype
)
assert np.all(f_val == expected_val)
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_float64(self, cast_policy):
"""Test arange constructor, on float64 outputs."""
with config.change_flags(cast_policy=cast_policy):
start, stop, step = dscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
if config.cast_policy == "custom":
assert out.dtype == start.type.dtype
elif config.cast_policy == "numpy":
numpy_dtype = np.arange(
np.array(0, dtype=start.dtype),
np.array(1, dtype=stop.dtype),
np.array(1, dtype=step.dtype),
).dtype
assert out.dtype == numpy_dtype
elif config.cast_policy == "numpy+floatX":
assert out.dtype == config.floatX
else:
raise NotImplementedError(config.cast_policy)
arg_vals = [(0, 5, 1), (2, 11, 4), (-5, 1.1, 1.2), (1.3, 2, -2.1), (10, 2, 2)]
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_)
arg_vals = [
(0, 5, 1),
(2, 11, 4),
(-5, 1.1, 1.2),
(1.3, 2, -2.1),
(10, 2, 2),
]
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom":
expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype
)
elif config.cast_policy == "numpy+floatX":
expected_val = np.arange(start_v_, stop_v_, step_v_)
assert np.all(f_val == expected_val)
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_default_step(self, cast_policy):
"""Test that arange constructor uses the correct default step."""
with config.change_flags(cast_policy=cast_policy):
start, stop = iscalars("start", "stop")
out = arange(start, stop)
f = function([start, stop], out)
if config.cast_policy == "custom":
expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype
)
elif config.cast_policy in ("numpy", "numpy+floatX"):
expected_val = np.arange(start_v_, stop_v_, step_v_, dtype=out.dtype)
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f_val == expected_val)
assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX":
assert out.dtype == np.arange(np.int32(0), np.int32(1)).dtype
def test_float64(self):
# Test arange constructor, on float64 outputs
start, stop, step = dscalars("start", "stop", "step")
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert np.all(f(0, 5) == np.arange(0, 5))
assert np.all(f(-5, 1) == np.arange(-5, 1))
assert np.all(f(0, 0) == np.arange(0, 0))
dstart, dstop = dscalars("start", "stop")
dout = arange(dstart, dstop)
df = function([dstart, dstop], dout)
assert dout.dtype == dstart.type.dtype
# print df(0.2, 5.3)
# print np.arange(0.2, 5.3)
assert np.all(df(0.2, 5.3) == np.arange(0.2, 5.3))
assert np.all(df(0.8, 5.3) == np.arange(0.8, 5.3))
assert np.all(df(-0.7, 5.3) == np.arange(-0.7, 5.3))
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_default_start(self, cast_policy):
"""Test that arange constructor uses the correct default start."""
with config.change_flags(cast_policy=cast_policy):
stop = iscalar("stop")
out = arange(stop)
f = function([stop], out)
assert out.dtype == start.type.dtype
arg_vals = [(0, 5, 1), (2, 11, 4), (-5, 1.1, 1.2), (1.3, 2, -2.1), (10, 2, 2)]
for arg_v in arg_vals:
start_v, stop_v, step_v = arg_v
start_v_, stop_v_, step_v_ = np.asarray(arg_v, dtype=start.type.dtype)
f_val = f(start_v_, stop_v_, step_v_)
if config.cast_policy == "custom":
expected_val = np.arange(
start_v, stop_v, step_v, dtype=start.type.dtype
)
elif config.cast_policy in ("numpy", "numpy+floatX"):
expected_val = np.arange(start_v_, stop_v_, step_v_)
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f_val == expected_val)
def test_default_step(self):
# Test that arange constructor uses the correct default step
start, stop = iscalars("start", "stop")
out = arange(start, stop)
f = function([start, stop], out)
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert out.dtype == np.arange(np.int32(0), np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5) == np.arange(0, 5))
assert np.all(f(-5, 1) == np.arange(-5, 1))
assert np.all(f(0, 0) == np.arange(0, 0))
dstart, dstop = dscalars("start", "stop")
dout = arange(dstart, dstop)
df = function([dstart, dstop], dout)
assert dout.dtype == dstart.type.dtype
# print df(0.2, 5.3)
# print np.arange(0.2, 5.3)
assert np.all(df(0.2, 5.3) == np.arange(0.2, 5.3))
assert np.all(df(0.8, 5.3) == np.arange(0.8, 5.3))
assert np.all(df(-0.7, 5.3) == np.arange(-0.7, 5.3))
def test_default_start(self):
# Test that arange constructor uses the correct default start
stop = iscalar("stop")
out = arange(stop)
f = function([stop], out)
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert out.dtype == np.arange(np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(8) == np.arange(8))
assert np.all(f(-2) == np.arange(-2))
fstop = fscalar("stop")
fout = arange(fstop)
ff = function([fstop], fout)
if config.cast_policy == "custom":
assert fout.dtype == fstop.type.dtype
elif config.cast_policy == "numpy":
assert fout.dtype == np.arange(np.float32(1)).dtype
elif config.cast_policy == "numpy+floatX":
if config.floatX == "float32":
assert fout.dtype == "float32"
else:
assert fout.dtype == np.arange(np.float32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
fstop_values = [0.2, -0.7, 8.5]
for fstop_v in fstop_values:
fstop_v32 = np.float32(fstop_v)
assert np.all(ff(fstop_v32) == np.arange(fstop_v))
def test_upcast(self):
# Test that arange computes output type adequately
if config.cast_policy == "custom":
assert arange(iscalar()).dtype == "int64"
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
elif config.cast_policy in ("numpy", "numpy+floatX"):
for dtype in get_numeric_types():
# Test with a single argument.
arange_dtype = arange(scalar(dtype=str(dtype))).dtype
numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype
if (
dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
):
# We want a float32 arange.
assert arange_dtype == "float32"
assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX":
assert out.dtype == np.arange(np.int32(1)).dtype
assert np.all(f(8) == np.arange(8))
assert np.all(f(-2) == np.arange(-2))
fstop = fscalar("stop")
fout = arange(fstop)
ff = function([fstop], fout)
if config.cast_policy == "custom":
assert fout.dtype == fstop.type.dtype
elif config.cast_policy == "numpy+floatX":
if config.floatX == "float32":
assert fout.dtype == "float32"
else:
# Follow numpy.
assert arange_dtype == numpy_dtype
# Test with two arguments.
for stop_dtype in get_numeric_types():
arange_dtype = arange(
start=scalar(dtype=str(dtype)),
stop=scalar(dtype=str(stop_dtype)),
).dtype
numpy_dtype = np.arange(
start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype),
).dtype
assert fout.dtype == np.arange(np.float32(1)).dtype
fstop_values = [0.2, -0.7, 8.5]
for fstop_v in fstop_values:
fstop_v32 = np.float32(fstop_v)
assert np.all(ff(fstop_v32) == np.arange(fstop_v))
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_upcast(self, cast_policy):
"""Test that arange computes output type adequately."""
with config.change_flags(cast_policy=cast_policy):
if config.cast_policy == "custom":
assert arange(iscalar()).dtype == "int64"
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
elif config.cast_policy == "numpy+floatX":
for dtype in get_numeric_types():
# Test with a single argument.
arange_dtype = arange(scalar(dtype=str(dtype))).dtype
numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype
if (
dtype != "float64"
and stop_dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
......@@ -2423,22 +2454,19 @@ class TestARange:
# Follow numpy.
assert arange_dtype == numpy_dtype
# Test with three arguments.
for step_dtype in get_numeric_types():
# Test with two arguments.
for stop_dtype in get_numeric_types():
arange_dtype = arange(
start=scalar(dtype=str(dtype)),
stop=scalar(dtype=str(stop_dtype)),
step=scalar(dtype=str(step_dtype)),
).dtype
numpy_dtype = np.arange(
start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype),
step=np.array(1, dtype=step_dtype),
).dtype
if (
dtype != "float64"
and stop_dtype != "float64"
and step_dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
......@@ -2448,8 +2476,32 @@ class TestARange:
else:
# Follow numpy.
assert arange_dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
# Test with three arguments.
for step_dtype in get_numeric_types():
arange_dtype = arange(
start=scalar(dtype=str(dtype)),
stop=scalar(dtype=str(stop_dtype)),
step=scalar(dtype=str(step_dtype)),
).dtype
numpy_dtype = np.arange(
start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype),
step=np.array(1, dtype=step_dtype),
).dtype
if (
dtype != "float64"
and stop_dtype != "float64"
and step_dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
):
# We want a float32 arange.
assert arange_dtype == "float32"
else:
# Follow numpy.
assert arange_dtype == numpy_dtype
def test_dtype_cache(self):
# Checks that the same Op is returned on repeated calls to arange
......@@ -2465,74 +2517,79 @@ class TestARange:
assert out2.owner.op is out3.owner.op
assert out3.owner.op is not out4.owner.op
def test_infer_shape(self):
start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step)
mode = config.mode
if mode == "FAST_COMPILE":
mode = "FAST_RUN"
mode = compile.mode.get_mode(mode).excluding("fusion")
f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
numpy_dtype = np.arange(
np.array(0, dtype=start.dtype),
np.array(1, dtype=stop.dtype),
np.array(1, dtype=step.dtype),
).dtype
assert out.dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5, 1) == len(np.arange(0, 5, 1)))
assert np.all(f(2, 11, 4) == len(np.arange(2, 11, 4)))
assert np.all(f(-5, 1, 1) == len(np.arange(-5, 1, 1)))
assert np.all(f(10, 2, -2) == len(np.arange(10, 2, -2)))
assert np.all(f(10, 2, 2) == len(np.arange(10, 2, 2)))
assert np.all(f(0, 0, 1) == len(np.arange(0, 0, 1)))
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
assert out.dtype == np.arange(np.int32(0), np.int32(1), np.int32(1)).dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(0, 5) == len(np.arange(0, 5)))
assert np.all(f(2, 11) == len(np.arange(2, 11)))
assert np.all(f(-5, 1) == len(np.arange(-5, 1)))
assert np.all(f(10, 2) == len(np.arange(10, 2)))
assert np.all(f(10, 2) == len(np.arange(10, 2)))
assert np.all(f(0, 0) == len(np.arange(0, 0)))
assert np.all(f(-64, 64) == len(np.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128]
assert arange(-64, 64, 2).shape.eval() == [64]
out = arange(0, stop, 1)
f = function([stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 2
# [Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy in ("numpy", "numpy+floatX"):
numpy_dtype = np.arange(0, np.array(1, dtype=stop.dtype), 1).dtype
assert out.dtype == numpy_dtype
else:
raise NotImplementedError(config.cast_policy)
assert np.all(f(5) == len(np.arange(0, 5)))
assert np.all(f(11) == len(np.arange(0, 11)))
assert np.all(f(1) == len(np.arange(0, 1)))
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(0) == len(np.arange(0, 0)))
@pytest.mark.parametrize(
"cast_policy",
[
"custom",
"numpy+floatX",
],
)
def test_infer_shape(self, cast_policy):
with config.change_flags(cast_policy=cast_policy):
start, stop, step = iscalars("start", "stop", "step")
out = arange(start, stop, step)
mode = config.mode
if mode == "FAST_COMPILE":
mode = "FAST_RUN"
mode = compile.mode.get_mode(mode).excluding("fusion")
f = function([start, stop, step], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 9
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX":
numpy_dtype = np.arange(
np.array(0, dtype=start.dtype),
np.array(1, dtype=stop.dtype),
np.array(1, dtype=step.dtype),
).dtype
assert out.dtype == numpy_dtype
assert np.all(f(0, 5, 1) == len(np.arange(0, 5, 1)))
assert np.all(f(2, 11, 4) == len(np.arange(2, 11, 4)))
assert np.all(f(-5, 1, 1) == len(np.arange(-5, 1, 1)))
assert np.all(f(10, 2, -2) == len(np.arange(10, 2, -2)))
assert np.all(f(10, 2, 2) == len(np.arange(10, 2, 2)))
assert np.all(f(0, 0, 1) == len(np.arange(0, 0, 1)))
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX":
assert (
out.dtype == np.arange(np.int32(0), np.int32(1), np.int32(1)).dtype
)
assert np.all(f(0, 5) == len(np.arange(0, 5)))
assert np.all(f(2, 11) == len(np.arange(2, 11)))
assert np.all(f(-5, 1) == len(np.arange(-5, 1)))
assert np.all(f(10, 2) == len(np.arange(10, 2)))
assert np.all(f(10, 2) == len(np.arange(10, 2)))
assert np.all(f(0, 0) == len(np.arange(0, 0)))
assert np.all(f(-64, 64) == len(np.arange(-64, 64)))
assert arange(-64, 64).shape.eval() == [128]
assert arange(-64, 64, 2).shape.eval() == [64]
out = arange(0, stop, 1)
f = function([stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 2
# [Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX":
numpy_dtype = np.arange(0, np.array(1, dtype=stop.dtype), 1).dtype
assert out.dtype == numpy_dtype
assert np.all(f(5) == len(np.arange(0, 5)))
assert np.all(f(11) == len(np.arange(0, 11)))
assert np.all(f(1) == len(np.arange(0, 1)))
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(2) == len(np.arange(0, 2)))
assert np.all(f(0) == len(np.arange(0, 0)))
class TestNdGrid:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论