提交 00c9e1f7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make sure Numba's reshape gets C-ordered arrays

An extra case for zero-dimensional reshapes was added as well.
上级 3bde5122
...@@ -1126,9 +1126,17 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -1126,9 +1126,17 @@ def numba_funcify_Cast(op, node, **kwargs):
def numba_funcify_Reshape(op, **kwargs): def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim ndim = op.ndim
@numba.njit(inline="always") if ndim == 0:
def reshape(x, shape):
return np.reshape(x, to_fixed_tuple(shape, ndim)) @numba.njit(inline="always")
def reshape(x, shape):
return x.item()
else:
@numba.njit(inline="always")
def reshape(x, shape):
return np.reshape(np.ascontiguousarray(x), to_fixed_tuple(shape, ndim))
return reshape return reshape
......
...@@ -783,6 +783,7 @@ def test_Cast(v, dtype): ...@@ -783,6 +783,7 @@ def test_Cast(v, dtype):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, shape, ndim", "v, shape, ndim",
[ [
(set_test_value(aet.vector(), np.array([4], dtype=config.floatX)), (), 0),
(set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2), (set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2),
( (
set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)), set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论