提交 d3a435a9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parametrize test_indexed_result_shape

上级 bc0deb3f
...@@ -2463,27 +2463,69 @@ def test_basic_shape(): ...@@ -2463,27 +2463,69 @@ def test_basic_shape():
assert get_test_value(res) == (2,) assert get_test_value(res) == (2,)
@config.change_flags(compute_test_value="raise") def idx_as_tensor(x):
def test_indexed_result_shape():
_test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))
test_shape = (5, 6, 7, 8)
test_array = np.arange(np.prod(test_shape)).reshape(test_shape)
def idx_as_tensor(x):
if isinstance(x, (slice, type(None))): if isinstance(x, (slice, type(None))):
return x return x
else: else:
return at.as_tensor(x) return at.as_tensor(x)
def bcast_shape_tuple(x):
def bcast_shape_tuple(x):
if not hasattr(x, "shape"): if not hasattr(x, "shape"):
return x return x
return tuple( return tuple(
s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable) s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
) )
def compare_index_shapes(test_array, test_idx):
test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))
@pytest.mark.parametrize(
"test_array, test_idx",
[
(np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), (slice(None, None),)),
(np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), (2,)),
(np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:1]),
(np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:2]),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:2] + (slice(None, None),),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None),) + test_idx[:1],
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None) + test_idx[1:2],
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(np.array(1), slice(None, None), None),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None, np.array(1)),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (slice(None, None),) + test_idx[1:2],
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (slice(None, None),) + test_idx[1:2] + (slice(None, None),),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (None,) + test_idx[1:2],
),
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
],
)
@config.change_flags(compute_test_value="raise")
def test_indexed_result_shape(test_array, test_idx):
res = indexed_result_shape( res = indexed_result_shape(
at.as_tensor(test_array).shape, [idx_as_tensor(i) for i in test_idx] at.as_tensor(test_array).shape, [idx_as_tensor(i) for i in test_idx]
) )
...@@ -2499,55 +2541,6 @@ def test_indexed_result_shape(): ...@@ -2499,55 +2541,6 @@ def test_indexed_result_shape():
exp_res = test_array[test_idx].shape exp_res = test_array[test_idx].shape
assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res) assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res)
# Simple basic indices
test_idx = (slice(None, None),)
compare_index_shapes(test_array, test_idx)
# Advanced indices
test_idx = (2,)
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:1]
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:2]
compare_index_shapes(test_array, test_idx)
# A Mix of advanced and basic indices
test_idx = _test_idx[:2] + (slice(None, None),)
compare_index_shapes(test_array, test_idx)
test_idx = (slice(None, None),) + _test_idx[1:]
compare_index_shapes(test_array, test_idx)
test_idx = (slice(None, None), None) + _test_idx[1:2]
compare_index_shapes(test_array, test_idx)
test_idx = (np.array(1), slice(None, None), None)
compare_index_shapes(test_array, test_idx)
test_idx = (slice(None, None), None, np.array(1))
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:1] + (slice(None, None),) + _test_idx[1:2]
compare_index_shapes(test_array, test_idx)
test_idx = (
_test_idx[:1] + (slice(None, None),) + _test_idx[1:2] + (slice(None, None),)
)
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:1] + (None,) + _test_idx[1:2]
compare_index_shapes(test_array, test_idx)
test_shape = (5, 4)
test_array = np.arange(np.prod(test_shape)).reshape(test_shape)
test_idx = ([1, 3, 2], slice(1, 3))
compare_index_shapes(test_array, test_idx)
test_idx = (slice(1, 3), [1, 3, 2])
compare_index_shapes(test_array, test_idx)
def test_symbolic_slice(): def test_symbolic_slice():
x = tensor4("x") x = tensor4("x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论