提交 ae9b8c2f authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix how a function is accessed and add missing test for that case.

上级 04eb6bf9
...@@ -55,7 +55,7 @@ class SpSum(Op): ...@@ -55,7 +55,7 @@ class SpSum(Op):
z = tensor.tensor(broadcastable=(), dtype=x.dtype) z = tensor.tensor(broadcastable=(), dtype=x.dtype)
elif self.axis == 0: elif self.axis == 0:
if x.format == 'csc': if x.format == 'csc':
z = T.tensor(broadcastable=(False,), dtype=x.dtype) z = tensor.tensor(broadcastable=(False,), dtype=x.dtype)
elif x.format == 'csr': elif x.format == 'csr':
#return SparseVector() #WRITEME! #return SparseVector() #WRITEME!
raise NotImplementedError() raise NotImplementedError()
...@@ -66,7 +66,7 @@ class SpSum(Op): ...@@ -66,7 +66,7 @@ class SpSum(Op):
#return SparseVector() #WRITEME! #return SparseVector() #WRITEME!
raise NotImplementedError() raise NotImplementedError()
elif x.format == 'csr': elif x.format == 'csr':
z = T.tensor(broadcastable=(False,), dtype=x.dtype) z = tensor.tensor(broadcastable=(False,), dtype=x.dtype)
else: else:
raise NotImplementedError() raise NotImplementedError()
else: else:
......
...@@ -362,6 +362,44 @@ class TestSP(unittest.TestCase): ...@@ -362,6 +362,44 @@ class TestSP(unittest.TestCase):
# symbolic stuff # symbolic stuff
utt.verify_grad(d, [kvals]) utt.verify_grad(d, [kvals])
def test_sp_sum(self):
# TODO: test both grad.
for format,cast in [("csc",scipy.sparse.csc_matrix), ("csr",scipy.sparse.csr_matrix)]:
x = theano.sparse.SparseType(format=format,
dtype=theano.config.floatX)()
x_data = numpy.arange(20).reshape(5,4).astype(theano.config.floatX)
# Sum on all axis
z = theano.sparse.sandbox.sp.sp_sum(x)
assert z.type.broadcastable==()
f = theano.function([x], z)
x_val = cast(x_data)
out = f(x_val)
assert out == x_val.sum()
# Sum on axis 0
try:
z = theano.sparse.sandbox.sp.sp_sum(x, axis=0)
assert z.type.broadcastable==(False,)
f = theano.function([x], z)
x_val = cast(x_data)
out = f(x_val)
assert (out == x_val.sum(axis=0)).all()
except NotImplementedError:
pass
# Sum on axis 1
try:
z = theano.sparse.sandbox.sp.sp_sum(x, axis=1)
assert z.type.broadcastable==(False,)
f = theano.function([x], z)
x_val = cast(x_data)
out = f(x_val)
expected = numpy.asarray(x_val.sum(axis=1)).reshape(x_val.shape[0])
assert (out == expected).all()
except NotImplementedError:
pass
def test_diagonal(): def test_diagonal():
for K in 1, 5: for K in 1, 5:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论