提交 8267d0e4 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

More robust check for multiple integer indices in numba ravel_multidimensional_idx rewrites

上级 4e85676a
...@@ -85,7 +85,7 @@ from pytensor.tensor.subtensor import ( ...@@ -85,7 +85,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor, inc_subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node): ...@@ -1981,7 +1981,7 @@ def ravel_multidimensional_bool_idx(fgraph, node):
if any( if any(
( (
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")) (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
or isinstance(idx.type, NoneTypeT) or isinstance(idx.type, NoneTypeT)
) )
for idx in idxs for idx in idxs
...@@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node): ...@@ -2052,7 +2052,7 @@ def ravel_multidimensional_int_idx(fgraph, node):
int_idxs = [ int_idxs = [
(i, idx) (i, idx)
for i, idx in enumerate(idxs) for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int")) if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
] ]
if len(int_idxs) != 1: if len(int_idxs) != 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论