提交 528ee54f authored 作者: Anatoly's avatar Anatoly 提交者: Brandon T. Willard

add sum method to _sparse_py_operators

上级 b088cc8f
...@@ -299,6 +299,9 @@ class _sparse_py_operators: ...@@ -299,6 +299,9 @@ class _sparse_py_operators:
def __rdot__(right, left): def __rdot__(right, left):
return structured_dot(left, right) return structured_dot(left, right)
def sum(self, axis=None, sparse_grad=False):
return sp_sum(self, axis=axis, sparse_grad=sparse_grad)
dot = __dot__ dot = __dot__
# N.B. THIS IS COMMENTED OUT ON PURPOSE!!! # N.B. THIS IS COMMENTED OUT ON PURPOSE!!!
......
...@@ -2039,12 +2039,17 @@ class TestSpSum(utt.InferShapeTester): ...@@ -2039,12 +2039,17 @@ class TestSpSum(utt.InferShapeTester):
self.op_class = sparse.SpSum self.op_class = sparse.SpSum
self.op = sparse.sp_sum self.op = sparse.sp_sum
def test_op(self): @pytest.mark.parametrize("op_type", ["func", "method"])
def test_op(self, op_type):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
for axis in self.possible_axis: for axis in self.possible_axis:
variable, data = sparse_random_inputs(format, shape=(10, 10)) variable, data = sparse_random_inputs(format, shape=(10, 10))
z = sparse.sp_sum(variable[0], axis=axis) if op_type == "func":
z = sparse.sp_sum(variable[0], axis=axis)
if op_type == "method":
z = variable[0].sum(axis=axis)
if axis is None: if axis is None:
assert z.type.broadcastable == () assert z.type.broadcastable == ()
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论