提交 0f0e67ad authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba FillDiagonal: Do not mutate input

上级 d4a0433d
...@@ -100,16 +100,19 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -100,16 +100,19 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
def numba_funcify_FillDiagonal(op, **kwargs): def numba_funcify_FillDiagonal(op, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def filldiagonal(a, val): def filldiagonal(a, val):
a = a.copy()
np.fill_diagonal(a, val) np.fill_diagonal(a, val)
return a return a
return filldiagonal cache_version = 1
return filldiagonal, cache_version
@register_funcify_default_op_cache_key(FillDiagonalOffset) @register_funcify_default_op_cache_key(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs): def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def filldiagonaloffset(a, val, offset): def filldiagonaloffset(a, val, offset):
a = a.copy()
height, width = a.shape height, width = a.shape
offset_item = offset.item() offset_item = offset.item()
if offset >= 0: if offset >= 0:
...@@ -128,7 +131,8 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs): ...@@ -128,7 +131,8 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
# return a # return a
return b.reshape(a.shape) return b.reshape(a.shape)
return filldiagonaloffset cache_version = 1
return filldiagonaloffset, cache_version
@register_funcify_default_op_cache_key(RavelMultiIndex) @register_funcify_default_op_cache_key(RavelMultiIndex)
......
...@@ -84,7 +84,6 @@ def test_CumOp(val, axis, mode): ...@@ -84,7 +84,6 @@ def test_CumOp(val, axis, mode):
) )
@pytest.mark.xfail(reason="Implementation works inplace!")
def test_FillDiagonal(): def test_FillDiagonal():
a = pt.lmatrix("a") a = pt.lmatrix("a")
test_a = np.zeros((10, 2), dtype="int64") test_a = np.zeros((10, 2), dtype="int64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论