提交 d351b09d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix incorrect dtypes in LUFactor and PivotToPremutations

上级 4378d482
......@@ -83,7 +83,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@numba_funcify.register(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse
dtype = node.inputs[0].dtype
dtype = node.outputs[0].dtype
@numba_njit
def numba_pivot_to_permutation(piv):
......
......@@ -604,7 +604,7 @@ class PivotToPermutations(Op):
def perform(self, node, inputs, outputs):
[pivots] = inputs
p_inv = np.arange(len(pivots), dtype=pivots.dtype)
p_inv = np.arange(len(pivots), dtype="int64")
for i in range(len(pivots)):
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
......@@ -639,7 +639,7 @@ class LUFactor(Op):
)
LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
pivots = vector(shape=(A.type.shape[0],), dtype="int64")
pivots = vector(shape=(A.type.shape[0],), dtype="int32")
return Apply(self, [A], [LU, pivots])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论