提交 3ff76039 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba DimShuffle: validate squeeze

上级 4298b761
......@@ -397,10 +397,11 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@register_funcify_default_op_cache_key(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs):
def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squezing dimensions in one call
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
new_order = tuple(op._new_order)
drop = tuple(op.drop)
shape_template = (1,) * node.outputs[0].ndim
strides_template = (0,) * node.outputs[0].ndim
......@@ -409,6 +410,11 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
@numba_basic.numba_njit
def squeeze_to_0d(x):
if not x.size == 1:
raise ValueError(
"DimShuffle: Attempting to squeeze axes with size not equal to one"
)
assert x.size == 1
return as_strided(x, shape=(), strides=())
return squeeze_to_0d
......@@ -428,10 +434,17 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
new_strides = numba_basic.tuple_setitem(
new_strides, i, old_strides[o]
)
if drop:
for dropped_dim in drop:
if old_shape[dropped_dim] != 1:
raise ValueError(
"DimShuffle: Attempting to squeeze axes with size not equal to one"
)
return as_strided(x, shape=new_shape, strides=new_strides)
return dimshuffle
cache_version = 1
return dimshuffle, cache_version
@register_funcify_default_op_cache_key(Softmax)
......
......@@ -19,6 +19,7 @@ from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
compare_numba_and_py,
numba_mode,
scalar_my_multi_out,
)
from tests.tensor.test_elemwise import (
......@@ -217,6 +218,17 @@ def test_Dimshuffle_non_contiguous():
assert func(np.zeros(3), np.array([1])).ndim == 0
def test_Dimshuffle_squeeze_errors():
x = pt.tensor3("x", shape=(4, None, 5))
out = pt.squeeze(x, axis=1)
assert out.type.shape == (4, 5)
fn = function([x], out, mode=numba_mode)
with pytest.raises(
ValueError, match="Attempting to squeeze axes with size not equal to one"
):
fn(np.zeros((4, 2, 5)))
@pytest.mark.parametrize(
"careduce_fn, axis, v",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论