提交 ac11da62 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba DimShuffle: special case for 0d input

This circumvents a bug when DimShuffle of a scalar shows up inside a Blockwise, as the outer indexing yields a float (as opposed to a numpy scalar) which has no `.shape` attribute.
上级 31304be2
...@@ -466,6 +466,16 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs): ...@@ -466,6 +466,16 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
return squeeze_to_0d return squeeze_to_0d
elif op.input_ndim == 0:
# DimShuffle can only be an expand_dims or a no_op
# This branch uses asarray in case we get a scalar due to https://github.com/numba/numba/issues/10358
new_shape = shape_template
new_strides = strides_template
@numba_basic.numba_njit
def dimshuffle(x):
return as_strided(np.asarray(x), shape=new_shape, strides=new_strides)
else: else:
@numba_basic.numba_njit @numba_basic.numba_njit
...@@ -490,7 +500,7 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs): ...@@ -490,7 +500,7 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
return as_strided(x, shape=new_shape, strides=new_strides) return as_strided(x, shape=new_shape, strides=new_strides)
cache_version = 1 cache_version = 2
return dimshuffle, cache_version return dimshuffle, cache_version
......
...@@ -5,6 +5,7 @@ from pytensor import function ...@@ -5,6 +5,7 @@ from pytensor import function
from pytensor.tensor import lvector, tensor, tensor3 from pytensor.tensor import lvector, tensor, tensor3
from pytensor.tensor.basic import Alloc, ARange, constant from pytensor.tensor.basic import Alloc, ARange, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.nlinalg import SVD, Det from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky from pytensor.tensor.slinalg import Cholesky, cholesky
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
...@@ -80,3 +81,12 @@ def test_blockwise_alloc(): ...@@ -80,3 +81,12 @@ def test_blockwise_alloc():
assert out.type.ndim == 3 assert out.type.ndim == 3
compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False) compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False)
def test_blockwise_scalar_dimshuffle():
x = lvector("x")
blockwise_scalar_ds = Blockwise(
DimShuffle(input_ndim=0, new_order=["x", "x"]), signature="()->(1,1)"
)
out = blockwise_scalar_ds(x)
compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论