提交 3ff76039 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba DimShuffle: validate squeeze

上级 4298b761
...@@ -397,10 +397,11 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -397,10 +397,11 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@register_funcify_default_op_cache_key(DimShuffle) @register_funcify_default_op_cache_key(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs): def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squezing dimensions in one call # We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squezing dimensions in one call
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays. # Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
new_order = tuple(op._new_order) new_order = tuple(op._new_order)
drop = tuple(op.drop)
shape_template = (1,) * node.outputs[0].ndim shape_template = (1,) * node.outputs[0].ndim
strides_template = (0,) * node.outputs[0].ndim strides_template = (0,) * node.outputs[0].ndim
...@@ -409,6 +410,11 @@ def numba_funcify_DimShuffle(op, node, **kwargs): ...@@ -409,6 +410,11 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def squeeze_to_0d(x): def squeeze_to_0d(x):
if not x.size == 1:
raise ValueError(
"DimShuffle: Attempting to squeeze axes with size not equal to one"
)
assert x.size == 1
return as_strided(x, shape=(), strides=()) return as_strided(x, shape=(), strides=())
return squeeze_to_0d return squeeze_to_0d
...@@ -428,10 +434,17 @@ def numba_funcify_DimShuffle(op, node, **kwargs): ...@@ -428,10 +434,17 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
new_strides = numba_basic.tuple_setitem( new_strides = numba_basic.tuple_setitem(
new_strides, i, old_strides[o] new_strides, i, old_strides[o]
) )
if drop:
for dropped_dim in drop:
if old_shape[dropped_dim] != 1:
raise ValueError(
"DimShuffle: Attempting to squeeze axes with size not equal to one"
)
return as_strided(x, shape=new_shape, strides=new_strides) return as_strided(x, shape=new_shape, strides=new_strides)
return dimshuffle cache_version = 1
return dimshuffle, cache_version
@register_funcify_default_op_cache_key(Softmax) @register_funcify_default_op_cache_key(Softmax)
......
...@@ -19,6 +19,7 @@ from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum ...@@ -19,6 +19,7 @@ from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
numba_mode,
scalar_my_multi_out, scalar_my_multi_out,
) )
from tests.tensor.test_elemwise import ( from tests.tensor.test_elemwise import (
...@@ -217,6 +218,17 @@ def test_Dimshuffle_non_contiguous(): ...@@ -217,6 +218,17 @@ def test_Dimshuffle_non_contiguous():
assert func(np.zeros(3), np.array([1])).ndim == 0 assert func(np.zeros(3), np.array([1])).ndim == 0
def test_Dimshuffle_squeeze_errors():
x = pt.tensor3("x", shape=(4, None, 5))
out = pt.squeeze(x, axis=1)
assert out.type.shape == (4, 5)
fn = function([x], out, mode=numba_mode)
with pytest.raises(
ValueError, match="Attempting to squeeze axes with size not equal to one"
):
fn(np.zeros((4, 2, 5)))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"careduce_fn, axis, v", "careduce_fn, axis, v",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论