提交 bacdaf65 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix AdvancedIncSubtensor1 C-compilation with empty indices

上级 12213d05
......@@ -2561,7 +2561,11 @@ class AdvancedIncSubtensor1(COp):
and y_.type.dtype not in complex_dtypes
):
# Simple implementation for vector x, y cases
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
idx_may_be_neg = not (
# Empty idx needs no negative checks
idx_.type.shape[0] == 0
or (isinstance(idx_, Constant) and idx_.data.min() >= 0)
)
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
shape0 = x_.type.shape[0]
# This is used to make sure that when we trust the indices to be valid
......@@ -2680,7 +2684,7 @@ class AdvancedIncSubtensor1(COp):
"""
def c_code_cache_version(self):
return (9,)
return (10,)
def _check_runtime_broadcasting(
self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray
......
......@@ -22,7 +22,7 @@ from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import as_tensor, get_vector_length, vectorize
from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf, lt, switch
......@@ -1730,7 +1730,7 @@ class TestIncSubtensor:
)
class TestIncSubtensor1:
class TestAdvancedIncSubtensor1:
def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed())
......@@ -1817,6 +1817,16 @@ class TestIncSubtensor1:
out1val, out2val = f(mval, incval, incval)
utt.assert_allclose(out1val, out2val)
def test_empty_index(self):
x = fvector()
idx = constant([], dtype="int64")
y = idx.astype("float32")
out = advanced_inc_subtensor1(x, y, idx)
test_x = np.array([1, 2, 3], dtype="float32")
res = out.eval({x: test_x}, mode=Mode(optimizer=None))
np.testing.assert_array_equal(res, test_x)
class TestAdvancedSubtensor:
"""Test inc_subtensor and set_subtensor."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论