提交 0f0e67ad authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba FillDiagonal: Do not mutate input

上级 d4a0433d
......@@ -100,16 +100,19 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
def numba_funcify_FillDiagonal(op, **kwargs):
@numba_basic.numba_njit
def filldiagonal(a, val):
a = a.copy()
np.fill_diagonal(a, val)
return a
return filldiagonal
cache_version = 1
return filldiagonal, cache_version
@register_funcify_default_op_cache_key(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba_basic.numba_njit
def filldiagonaloffset(a, val, offset):
a = a.copy()
height, width = a.shape
offset_item = offset.item()
if offset >= 0:
......@@ -128,7 +131,8 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
# return a
return b.reshape(a.shape)
return filldiagonaloffset
cache_version = 1
return filldiagonaloffset, cache_version
@register_funcify_default_op_cache_key(RavelMultiIndex)
......
......@@ -84,7 +84,6 @@ def test_CumOp(val, axis, mode):
)
@pytest.mark.xfail(reason="Implementation works inplace!")
def test_FillDiagonal():
a = pt.lmatrix("a")
test_a = np.zeros((10, 2), dtype="int64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论