提交 a332c8b2 authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement Neg sparse Op in Numba backend

上级 ab7a4bdd
......@@ -24,6 +24,7 @@ from pytensor.sparse import (
GetItemListGrad,
GetItemScalar,
HStack,
Neg,
RowScaleCSC,
SparseFromDense,
Transpose,
......@@ -812,3 +813,15 @@ def numba_funcify_GetItemScalar(op, node, **kwargs):
return np.asarray(out, dtype=out_dtype)
return get_item_scalar_csc
@register_funcify_default_op_cache_key(Neg)
def numba_funcify_Neg(op, node, **kwargs):
@numba_basic.numba_njit
def neg(x):
z = x.copy()
z_data = z.data
z_data *= -1
return z
return neg
......@@ -654,3 +654,13 @@ def test_sparse_get_item_scalar_wrong_index(format):
with pytest.raises(IndexError, match="column index out of bounds"):
fn(x_test, 0, 5)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_neg(format):
x = ps.matrix(format, name="x", shape=(7, 6), dtype=config.floatX)
z = -x
x_test = sp.sparse.random(7, 6, density=0.4, format=format, dtype=config.floatX)
compare_numba_and_py_sparse([x], z, [x_test])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论