提交 7826c272 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed behaviour of at.subtensor.take when axis is None

上级 70c48957
...@@ -2770,10 +2770,8 @@ def take(a, indices, axis=None, mode="raise"): ...@@ -2770,10 +2770,8 @@ def take(a, indices, axis=None, mode="raise"):
if not isinstance(axis, (int, type(None))): if not isinstance(axis, (int, type(None))):
raise TypeError("`axis` must be an integer or None") raise TypeError("`axis` must be an integer or None")
if axis is None and indices.ndim == 1: if axis is None:
return advanced_subtensor1(a.flatten(), indices) return advanced_subtensor(a.flatten(), indices)
elif axis == 0 and indices.ndim == 1:
return advanced_subtensor1(a, indices)
elif axis < 0: elif axis < 0:
axis += a.ndim axis += a.ndim
......
...@@ -1403,33 +1403,24 @@ def test_take_basic(): ...@@ -1403,33 +1403,24 @@ def test_take_basic():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"a, index, axis, mode", "a_shape, index, axis, mode",
[ [
(matrix(), lvector(), -1, None), ((4, 5, 6), np.array([[1, 2, 3], [1, 2, 3]]), -1, None),
(matrix(), lvector(), 0, None), ((4, 5, 6), np.array([[1, 2, 3], [5, 6, 7]]), None, None),
(matrix(), lvector(), 1, None), ((4, 5, 6), np.array([[1, 2, 3], [1, 2, 3]]), 1, None),
(matrix(), lvector(), 1, "clip"), ((4, 5, 6), np.array([[1, 2, 3], [5, 6, 7]]), 1, "clip"),
(matrix(), lvector(), 1, "wrap"), ((4, 5, 6), np.array([[1, 2, 3], [5, 6, 7]]), 1, "wrap"),
], ],
) )
def test_take_cases(a, index, axis, mode): def test_take_cases(a_shape, index, axis, mode):
fn_mode = aesara.compile.mode.get_default_mode() a_val = np.random.random(size=a_shape).astype(config.floatX)
fn_mode = fn_mode.including( py_res = a_val.take(index, axis=axis, mode=mode)
"local_useless_subtensor",
"local_replace_AdvancedSubtensor",
)
f = aesara.function([a, index], a.take(index, axis=axis, mode=mode), mode=fn_mode)
a_val = np.arange(3 * 3).reshape((3, 3)).astype(config.floatX)
if mode is None: a = at.as_tensor_variable(a_val)
index_val = np.array([0, 1], dtype=np.int64) index = at.as_tensor_variable(index)
else:
index_val = np.array([-1, 2], dtype=np.int64)
py_res = a_val.take(index_val, axis=axis, mode=mode) f = aesara.function([], a.take(index, axis=axis, mode=mode))
f_res = f(a_val, index_val) f_res = f()
assert np.array_equal(py_res, f_res) assert np.array_equal(py_res, f_res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论