提交 d3a51f4f authored 作者: Valentin Bisson's avatar Valentin Bisson 提交者: Frederic

SpSum Op: pep8'ified, perform rolled back to basic (clearer/safer)…

SpSum Op: pep8'ified, perform rolled back to basic (clearer/safer) implementation, and some prints remnoved from tests.
上级 f699fa5e
......@@ -57,41 +57,34 @@ class SpSum(Op):
assert isinstance(x.type, theano.sparse.SparseType)
b = ()
if self.axis is not None:
b=(False,)
b = (False,)
z = tensor.tensor(broadcastable=b, dtype=x.dtype)
return gof.Apply(self, [x], [z])
def infer_shape(self, node, shapes):
r=None
r = None
if self.axis is None:
r=[()]
r = [()]
elif self.axis == 0:
r=[(shapes[0][1],)]
r = [(shapes[0][1],)]
else:
r=[(shapes[0][0],)]
r = [(shapes[0][0],)]
return r
def perform(self,node, (x,), (z,)):
if self.axis is None:
z[0] = numpy.asarray(x.sum())
else:
s = set(xrange(len(x.shape))) - set([self.axis])
myreshape = map((lambda i: x.shape[i]), s)
if x.format not in ('csc', 'csr'):
x = x.asformat(x.format)
z[0] = numpy.asarray(x.sum(axis = self.axis)).reshape(myreshape)
#case by case code for reference
#if self.axis == 0:
# if x.format == 'csc':
# z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[1],))
# else:
# z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[1],))
#if self.axis == 1:
# if x.format == 'csr':
# z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[0],))
# else:
# z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[0],))
if self.axis == 0:
if x.format == 'csc':
z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[1],))
else:
z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[1],))
elif self.axis == 1:
if x.format == 'csr':
z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[0],))
else:
z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[0],))
def grad(self,(x,), (gz,)):
if self.axis is None:
......
......@@ -361,11 +361,6 @@ class TestSP(unittest.TestCase):
utt.verify_grad(d, [kvals])
def test_sp_sum(self):
print '\n\n*************************************************'
print ' TEST SUM'
print '*************************************************'
from theano.sparse.sandbox.sp import SpSum
# TODO: test both grad.
......@@ -375,13 +370,13 @@ class TestSP(unittest.TestCase):
for format, cast in cases:
print 'format: %(format)s'%locals()
#print 'format: %(format)s' % locals()
x = theano.sparse.SparseType(format=format,
dtype=theano.config.floatX)()
x_data = numpy.arange(20).reshape(5,4).astype(theano.config.floatX)
# Sum on all axis
print 'sum on all axis...'
#print 'sum on all axis...'
z = theano.sparse.sandbox.sp.sp_sum(x)
assert z.type.broadcastable == ()
f = theano.function([x], z)
......@@ -391,7 +386,7 @@ class TestSP(unittest.TestCase):
assert out == expected
# Sum on axis 0
print 'sum on axis 0...'
#print 'sum on axis 0...'
z = theano.sparse.sandbox.sp.sp_sum(x, axis=0)
assert z.type.broadcastable == (False,)
f = theano.function([x], z)
......@@ -401,7 +396,7 @@ class TestSP(unittest.TestCase):
assert (out == expected).all()
# Sum on axis 1
print 'sum on axis 1...'
#print 'sum on axis 1...'
z = theano.sparse.sandbox.sp.sp_sum(x, axis=1)
assert z.type.broadcastable == (False,)
f = theano.function([x], z)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论