Unverified 提交 4730d0cb authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Fix numba impl of empty DimShuffle (#218)

Empty DimShuffles would return a scalar instead of an array, which can then lead to errors in ops that expect an array.
上级 8606498f
......@@ -841,7 +841,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
return x.item()
return np.reshape(x, ())
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
......
......@@ -210,6 +210,14 @@ def test_Dimshuffle(v, new_order):
)
def test_Dimshuffle_returns_array():
x = at.vector("x", shape=(1,))
y = 2 * at_elemwise.DimShuffle([True], [])(x)
func = pytensor.function([x], y, mode="NUMBA")
out = func(np.zeros(1, dtype=config.floatX))
assert out.ndim == 0
@pytest.mark.parametrize(
"careduce_fn, axis, v",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论