提交 09cdd75a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

PivotToPermutation: Stick with default int64 behavior

上级 2653ddea
......@@ -12,8 +12,8 @@ from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
@numba_basic.numba_njit
def _pivot_to_permutation(p, dtype):
p_inv = np.arange(len(p)).astype(dtype)
def _pivot_to_permutation(p):
p_inv = np.arange(len(p))
for i in range(len(p)):
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
return p_inv
......@@ -29,7 +29,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
IPIV = IPIV - 1
p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
p_inv = _pivot_to_permutation(IPIV)
perm = np.argsort(p_inv).astype("int32")
return perm, L, U
......
......@@ -97,11 +97,10 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@register_funcify_default_op_cache_key(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse
dtype = node.outputs[0].dtype
@numba_basic.numba_njit
def numba_pivot_to_permutation(piv):
p_inv = _pivot_to_permutation(piv, dtype)
p_inv = _pivot_to_permutation(piv)
if inverse:
return p_inv
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论