Unverified 提交 bfd72571 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Fix non-contiguous reshapes in numba backend (#255)

上级 0c9eb9ca
......@@ -841,7 +841,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
return np.reshape(x, ())
return np.reshape(np.ascontiguousarray(x), ())
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
......
......@@ -218,6 +218,17 @@ def test_Dimshuffle_returns_array():
assert out.ndim == 0
def test_Dimshuffle_non_contiguous():
"""The numba impl of reshape doesn't work with
non-contiguous arrays, make sure we work around that."""
x = at.dvector()
idx = at.vector(dtype="int64")
op = pytensor.tensor.elemwise.DimShuffle([True], [])
out = op(at.specify_shape(x[idx][::2], (1,)))
func = pytensor.function([x, idx], out, mode="NUMBA")
assert func(np.zeros(3), np.array([1])).ndim == 0
@pytest.mark.parametrize(
"careduce_fn, axis, v",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论