提交 d351b09d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix incorrect dtypes in LUFactor and PivotToPremutations

上级 4378d482
...@@ -83,7 +83,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -83,7 +83,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
@numba_funcify.register(PivotToPermutations) @numba_funcify.register(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs): def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse inverse = op.inverse
dtype = node.inputs[0].dtype dtype = node.outputs[0].dtype
@numba_njit @numba_njit
def numba_pivot_to_permutation(piv): def numba_pivot_to_permutation(piv):
......
...@@ -604,7 +604,7 @@ class PivotToPermutations(Op): ...@@ -604,7 +604,7 @@ class PivotToPermutations(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
[pivots] = inputs [pivots] = inputs
p_inv = np.arange(len(pivots), dtype=pivots.dtype) p_inv = np.arange(len(pivots), dtype="int64")
for i in range(len(pivots)): for i in range(len(pivots)):
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i] p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
...@@ -639,7 +639,7 @@ class LUFactor(Op): ...@@ -639,7 +639,7 @@ class LUFactor(Op):
) )
LU = matrix(shape=A.type.shape, dtype=A.type.dtype) LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
pivots = vector(shape=(A.type.shape[0],), dtype="int64") pivots = vector(shape=(A.type.shape[0],), dtype="int32")
return Apply(self, [A], [LU, pivots]) return Apply(self, [A], [LU, pivots])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论