提交 09cdd75a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

PivotToPermutation: Stick with default int64 behavior

上级 2653ddea
...@@ -12,8 +12,8 @@ from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix ...@@ -12,8 +12,8 @@ from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
@numba_basic.numba_njit @numba_basic.numba_njit
def _pivot_to_permutation(p, dtype): def _pivot_to_permutation(p):
p_inv = np.arange(len(p)).astype(dtype) p_inv = np.arange(len(p))
for i in range(len(p)): for i in range(len(p)):
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i] p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
return p_inv return p_inv
...@@ -29,7 +29,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a): ...@@ -29,7 +29,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
IPIV = IPIV - 1 IPIV = IPIV - 1
p_inv = _pivot_to_permutation(IPIV, dtype=dtype) p_inv = _pivot_to_permutation(IPIV)
perm = np.argsort(p_inv).astype("int32") perm = np.argsort(p_inv).astype("int32")
return perm, L, U return perm, L, U
......
...@@ -97,11 +97,10 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -97,11 +97,10 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@register_funcify_default_op_cache_key(PivotToPermutations) @register_funcify_default_op_cache_key(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs): def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse inverse = op.inverse
dtype = node.outputs[0].dtype
@numba_basic.numba_njit @numba_basic.numba_njit
def numba_pivot_to_permutation(piv): def numba_pivot_to_permutation(piv):
p_inv = _pivot_to_permutation(piv, dtype) p_inv = _pivot_to_permutation(piv)
if inverse: if inverse:
return p_inv return p_inv
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论