提交 8267d0e4 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

More robust check for multiple integer indices in numba ravel_multidimensional_idx rewrites

上级 4e85676a
......@@ -85,7 +85,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):
if any(
(
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
(isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
or isinstance(idx.type, NoneTypeT)
)
for idx in idxs
......@@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node):
int_idxs = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int"))
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]
if len(int_idxs) != 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论