提交 bacdaf65 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix AdvancedIncSubtensor1 C-compilation with empty indices

上级 12213d05
...@@ -2561,7 +2561,11 @@ class AdvancedIncSubtensor1(COp): ...@@ -2561,7 +2561,11 @@ class AdvancedIncSubtensor1(COp):
and y_.type.dtype not in complex_dtypes and y_.type.dtype not in complex_dtypes
): ):
# Simple implementation for vector x, y cases # Simple implementation for vector x, y cases
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0) idx_may_be_neg = not (
# Empty idx needs no negative checks
idx_.type.shape[0] == 0
or (isinstance(idx_, Constant) and idx_.data.min() >= 0)
)
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_) idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
shape0 = x_.type.shape[0] shape0 = x_.type.shape[0]
# This is used to make sure that when we trust the indices to be valid # This is used to make sure that when we trust the indices to be valid
...@@ -2680,7 +2684,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2680,7 +2684,7 @@ class AdvancedIncSubtensor1(COp):
""" """
def c_code_cache_version(self): def c_code_cache_version(self):
return (9,) return (10,)
def _check_runtime_broadcasting( def _check_runtime_broadcasting(
self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray
......
...@@ -22,7 +22,7 @@ from pytensor.graph.op import get_test_value ...@@ -22,7 +22,7 @@ from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar, int16 from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import as_tensor, get_vector_length, vectorize from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf, lt, switch from pytensor.tensor.math import exp, isinf, lt, switch
...@@ -1730,7 +1730,7 @@ class TestIncSubtensor: ...@@ -1730,7 +1730,7 @@ class TestIncSubtensor:
) )
class TestIncSubtensor1: class TestAdvancedIncSubtensor1:
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed()) self.rng = np.random.default_rng(seed=utt.fetch_seed())
...@@ -1817,6 +1817,16 @@ class TestIncSubtensor1: ...@@ -1817,6 +1817,16 @@ class TestIncSubtensor1:
out1val, out2val = f(mval, incval, incval) out1val, out2val = f(mval, incval, incval)
utt.assert_allclose(out1val, out2val) utt.assert_allclose(out1val, out2val)
def test_empty_index(self):
x = fvector()
idx = constant([], dtype="int64")
y = idx.astype("float32")
out = advanced_inc_subtensor1(x, y, idx)
test_x = np.array([1, 2, 3], dtype="float32")
res = out.eval({x: test_x}, mode=Mode(optimizer=None))
np.testing.assert_array_equal(res, test_x)
class TestAdvancedSubtensor: class TestAdvancedSubtensor:
"""Test inc_subtensor and set_subtensor.""" """Test inc_subtensor and set_subtensor."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论