提交 fb35c37d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make Reshape work for Numba scalars

上级 90021636
...@@ -1154,7 +1154,10 @@ def numba_funcify_Reshape(op, **kwargs): ...@@ -1154,7 +1154,10 @@ def numba_funcify_Reshape(op, **kwargs):
@numba.njit(inline="always") @numba.njit(inline="always")
def reshape(x, shape): def reshape(x, shape):
return np.reshape(np.ascontiguousarray(x), to_fixed_tuple(shape, ndim)) # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape(
np.ascontiguousarray(np.asarray(x)), to_fixed_tuple(shape, ndim)
)
return reshape return reshape
......
...@@ -817,6 +817,21 @@ def test_Reshape(v, shape, ndim): ...@@ -817,6 +817,21 @@ def test_Reshape(v, shape, ndim):
) )
def test_Reshape_scalar():
v = aet.vector()
v.tag.test_value = np.array([1.0], dtype=config.floatX)
g = Reshape(1)(v[0], (1,))
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, shape, fails", "v, shape, fails",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论