Unverified 提交 3082ed5e authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Rename sparse functions to match numpy array API (#1663)

* Rename `mul` -> `multiply` * Rename `sub` -> `subtract` * Space... the final frontier
上级 5547eb08
......@@ -2268,36 +2268,49 @@ def add(x, y):
raise NotImplementedError()
def sub(x, y):
def subtract(
x: SparseVariable | TensorVariable, y: SparseVariable | TensorVariable
) -> SparseVariable:
"""
Subtract two matrices, at least one of which is sparse.
This method will provide the right op according
to the inputs.
This method will provide the right op according to the inputs.
Parameters
----------
x
x : SparseVariable or TensorVariable
A matrix variable.
y
y : SparseVariable or TensorVariable
A matrix variable.
Returns
-------
A sparse matrix
`x` - `y`
result: SparseVariable
Result of `x - y`, as a sparse matrix.
Notes
-----
At least one of `x` and `y` must be a sparse matrix.
The grad will be structured only when one of the variable will be a dense
matrix.
The grad will be structured only when one of the variable will be a dense matrix.
"""
return x + (-y)
def sub(x, y):
warn(
"pytensor.sparse.sub is deprecated and will be removed in a future version. Use "
"pytensor.sparse.subtract instead.",
category=DeprecationWarning,
stacklevel=2,
)
return subtract(x, y)
sub.__doc__ = subtract.__doc__
class MulSS(Op):
# mul(sparse, sparse)
# See the doc of mul() for more detail
......@@ -2491,7 +2504,9 @@ class MulSV(Op):
mul_s_v = MulSV()
def mul(x, y):
def multiply(
x: SparseTensorType | TensorType, y: SparseTensorType | TensorType
) -> SparseVariable:
"""
Multiply elementwise two matrices, at least one of which is sparse.
......@@ -2499,21 +2514,21 @@ def mul(x, y):
Parameters
----------
x
x : SparseVariable
A matrix variable.
y
y : SparseVariable
A matrix variable.
Returns
-------
A sparse matrix
`x` * `y`
result: SparseVariable
The elementwise multiplication of `x` and `y`.
Notes
-----
At least one of `x` and `y` must be a sparse matrix.
The grad is regular, i.e. not structured.
The gradient is regular, i.e. not structured.
"""
x = as_sparse_or_tensor_variable(x)
......@@ -2541,6 +2556,20 @@ def mul(x, y):
raise NotImplementedError()
def mul(x, y):
warn(
"pytensor.sparse.mul is deprecated and will be removed in a future version. Use "
"pytensor.sparse.multiply instead.",
category=DeprecationWarning,
stacklevel=2,
)
return multiply(x, y)
mul.__doc__ = multiply.__doc__
class __ComparisonOpSS(Op):
"""
Used as a superclass for all comparisons between two sparses matrices.
......
......@@ -65,8 +65,8 @@ from pytensor.sparse import (
gt,
le,
lt,
mul,
mul_s_v,
multiply,
sampling_dot,
sp_ones_like,
square_diagonal,
......@@ -724,21 +724,21 @@ class TestAddMul:
def test_MulSS(self):
self._testSS(
mul,
multiply,
np.array([[1.0, 0], [3, 0], [0, 6]]),
np.array([[1.0, 2], [3, 0], [0, 6]]),
)
def test_MulSD(self):
self._testSD(
mul,
multiply,
np.array([[1.0, 0], [3, 0], [0, 6]]),
np.array([[1.0, 2], [3, 0], [0, 6]]),
)
def test_MulDS(self):
self._testDS(
mul,
multiply,
np.array([[1.0, 0], [3, 0], [0, 6]]),
np.array([[1.0, 2], [3, 0], [0, 6]]),
)
......@@ -783,7 +783,7 @@ class TestAddMul:
assert np.all(val.todense() == array1 + array2)
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=False)
elif op is mul:
elif op is multiply:
assert np.all(val.todense() == array1 * array2)
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=False)
......@@ -833,7 +833,7 @@ class TestAddMul:
continue
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=True)
elif op is mul:
elif op is multiply:
assert _is_sparse_variable(apb)
assert np.all(val.todense() == b.multiply(array1))
assert np.all(
......@@ -887,7 +887,7 @@ class TestAddMul:
b = b.data
if dtype1.startswith("float") and dtype2.startswith("float"):
verify_grad_sparse(op, [a, b], structured=True)
elif op is mul:
elif op is multiply:
assert _is_sparse_variable(apb)
ans = np.array([[1, 0], [9, 0], [0, 36]])
assert np.all(val.todense() == (a.multiply(array2)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论