提交 2b15ce1a authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Replace sparse Neg Op with structured_elemwise function

上级 db215037
......@@ -25,7 +25,6 @@ from pytensor.sparse import (
GetItemListGrad,
GetItemScalar,
HStack,
Neg,
RowScaleCSC,
SparseFromDense,
Transpose,
......@@ -816,18 +815,6 @@ def numba_funcify_GetItemScalar(op, node, **kwargs):
return get_item_scalar_csc
@register_funcify_default_op_cache_key(Neg)
def numba_funcify_Neg(op, node, **kwargs):
@numba_basic.numba_njit
def neg(x):
z = x.copy()
z_data = z.data
z_data *= -1
return z
return neg
@register_funcify_default_op_cache_key(Diag)
def numba_funcify_Diag(op, node, **kwargs):
input_format = node.inputs[0].type.format
......
......@@ -1246,52 +1246,6 @@ class Transpose(Op):
transpose = Transpose()
class Neg(Op):
"""Negative of the sparse matrix (i.e. multiply by ``-1``).
Notes
-----
The grad is regular, i.e. not structured.
"""
__props__ = ()
def __str__(self):
return "Sparse" + self.__class__.__name__
def make_node(self, x):
"""
Parameters
----------
x
Sparse matrix.
"""
x = as_sparse_variable(x)
assert x.format in ("csr", "csc")
return Apply(self, [x], [x.type()])
def perform(self, node, inputs, outputs):
(x,) = inputs
(out,) = outputs
assert _is_sparse(x)
out[0] = -x
def grad(self, inputs, gout):
(x,) = inputs
(gz,) = gout
assert _is_sparse_variable(x) and _is_sparse_variable(gz)
return (-gz,)
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
neg = Neg()
class ColScaleCSC(Op):
# Scale each columns of a sparse matrix by the corresponding
# element of a dense vector
......
......@@ -237,6 +237,13 @@ def _conj(x):
"""
@structured_elemwise(ptm.neg)
def neg(x):
"""
Compute -x for all non-zero elements of x.
"""
def conjugate(x):
_x = psb.as_sparse_variable(x)
if _x.type.dtype not in complex_dtypes:
......
......@@ -11,7 +11,6 @@ from pytensor.sparse.basic import (
get_item_2lists,
get_item_list,
get_item_scalar,
neg,
sp_ones_like,
sp_zeros_like,
transpose,
......@@ -23,6 +22,7 @@ from pytensor.sparse.math import (
le,
lt,
multiply,
neg,
sp_sum,
structured_conjugate,
structured_dot,
......
......@@ -656,16 +656,6 @@ def test_sparse_get_item_scalar_wrong_index(format):
fn(x_test, 0, 5)
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_neg(format):
x = ps.matrix(format, name="x", shape=(7, 6), dtype=config.floatX)
z = -x
x_test = sp.sparse.random(7, 6, density=0.4, format=format, dtype=config.floatX)
compare_numba_and_py_sparse([x], z, [x_test])
@pytest.mark.parametrize("format", ("csr", "csc"))
def test_sparse_diag(format):
x = ps.matrix(format, name="x", shape=(8, 8), dtype=config.floatX)
......
......@@ -26,7 +26,6 @@ from pytensor.sparse.basic import (
EnsureSortedIndices,
GetItemScalar,
HStack,
Neg,
Remove0,
SparseFromDense,
SparseTensorType,
......@@ -409,15 +408,6 @@ class TestSparseInferShape(utt.InferShapeTester):
Transpose,
)
def test_neg(self):
x = SparseTensorType("csr", dtype=config.floatX)()
self._compile_and_check(
[x],
[-x],
[scipy_sparse.csr_matrix(random_lil((10, 40), config.floatX, 3))],
Neg,
)
def test_remove0(self):
x = SparseTensorType("csr", dtype=config.floatX)()
self._compile_and_check(
......
......@@ -1470,3 +1470,5 @@ SqrTester = elemwise_checker(psm.sqr, lambda x: x * x)
SqrtTester = elemwise_checker(psm.sqrt, np.sqrt, gap=(0, 10))
ConjTester = elemwise_checker(psm.conjugate, np.conj, grad_test=False)
NegTester = elemwise_checker(psm.neg, np.negative, name="TestNeg")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论