提交 d3a51f4f authored 作者: Valentin Bisson's avatar Valentin Bisson 提交者: Frederic

SpSum Op: pep8'ified, perform rolled back to basic (clearer/safer)…

SpSum Op: pep8'ified, perform rolled back to basic (clearer/safer) implementation, and some prints remnoved from tests.
上级 f699fa5e
...@@ -57,41 +57,34 @@ class SpSum(Op): ...@@ -57,41 +57,34 @@ class SpSum(Op):
assert isinstance(x.type, theano.sparse.SparseType) assert isinstance(x.type, theano.sparse.SparseType)
b = () b = ()
if self.axis is not None: if self.axis is not None:
b=(False,) b = (False,)
z = tensor.tensor(broadcastable=b, dtype=x.dtype) z = tensor.tensor(broadcastable=b, dtype=x.dtype)
return gof.Apply(self, [x], [z]) return gof.Apply(self, [x], [z])
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
r=None r = None
if self.axis is None: if self.axis is None:
r=[()] r = [()]
elif self.axis == 0: elif self.axis == 0:
r=[(shapes[0][1],)] r = [(shapes[0][1],)]
else: else:
r=[(shapes[0][0],)] r = [(shapes[0][0],)]
return r return r
def perform(self,node, (x,), (z,)): def perform(self,node, (x,), (z,)):
if self.axis is None: if self.axis is None:
z[0] = numpy.asarray(x.sum()) z[0] = numpy.asarray(x.sum())
else: else:
s = set(xrange(len(x.shape))) - set([self.axis]) if self.axis == 0:
myreshape = map((lambda i: x.shape[i]), s) if x.format == 'csc':
if x.format not in ('csc', 'csr'): z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[1],))
x = x.asformat(x.format) else:
z[0] = numpy.asarray(x.sum(axis = self.axis)).reshape(myreshape) z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[1],))
elif self.axis == 1:
#case by case code for reference if x.format == 'csr':
#if self.axis == 0: z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[0],))
# if x.format == 'csc': else:
# z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[1],)) z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[0],))
# else:
# z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[1],))
#if self.axis == 1:
# if x.format == 'csr':
# z[0] = numpy.asarray(x.sum(axis=self.axis)).reshape((x.shape[0],))
# else:
# z[0] = numpy.asarray(x.asformat(x.format).sum(axis=self.axis)).reshape((x.shape[0],))
def grad(self,(x,), (gz,)): def grad(self,(x,), (gz,)):
if self.axis is None: if self.axis is None:
......
...@@ -361,27 +361,22 @@ class TestSP(unittest.TestCase): ...@@ -361,27 +361,22 @@ class TestSP(unittest.TestCase):
utt.verify_grad(d, [kvals]) utt.verify_grad(d, [kvals])
def test_sp_sum(self): def test_sp_sum(self):
print '\n\n*************************************************'
print ' TEST SUM'
print '*************************************************'
from theano.sparse.sandbox.sp import SpSum from theano.sparse.sandbox.sp import SpSum
# TODO: test both grad. # TODO: test both grad.
rng = numpy.random.RandomState(42) rng = numpy.random.RandomState(42)
from theano.sparse.basic import SparseFromDense,DenseFromSparse from theano.sparse.basic import SparseFromDense,DenseFromSparse
cases = [("csc", scipy.sparse.csc_matrix), ("csr", scipy.sparse.csr_matrix)] cases = [("csc", scipy.sparse.csc_matrix), ("csr", scipy.sparse.csr_matrix)]
for format, cast in cases: for format, cast in cases:
print 'format: %(format)s'%locals() #print 'format: %(format)s' % locals()
x = theano.sparse.SparseType(format=format, x = theano.sparse.SparseType(format=format,
dtype=theano.config.floatX)() dtype=theano.config.floatX)()
x_data = numpy.arange(20).reshape(5,4).astype(theano.config.floatX) x_data = numpy.arange(20).reshape(5,4).astype(theano.config.floatX)
# Sum on all axis # Sum on all axis
print 'sum on all axis...' #print 'sum on all axis...'
z = theano.sparse.sandbox.sp.sp_sum(x) z = theano.sparse.sandbox.sp.sp_sum(x)
assert z.type.broadcastable == () assert z.type.broadcastable == ()
f = theano.function([x], z) f = theano.function([x], z)
...@@ -391,7 +386,7 @@ class TestSP(unittest.TestCase): ...@@ -391,7 +386,7 @@ class TestSP(unittest.TestCase):
assert out == expected assert out == expected
# Sum on axis 0 # Sum on axis 0
print 'sum on axis 0...' #print 'sum on axis 0...'
z = theano.sparse.sandbox.sp.sp_sum(x, axis=0) z = theano.sparse.sandbox.sp.sp_sum(x, axis=0)
assert z.type.broadcastable == (False,) assert z.type.broadcastable == (False,)
f = theano.function([x], z) f = theano.function([x], z)
...@@ -401,7 +396,7 @@ class TestSP(unittest.TestCase): ...@@ -401,7 +396,7 @@ class TestSP(unittest.TestCase):
assert (out == expected).all() assert (out == expected).all()
# Sum on axis 1 # Sum on axis 1
print 'sum on axis 1...' #print 'sum on axis 1...'
z = theano.sparse.sandbox.sp.sp_sum(x, axis=1) z = theano.sparse.sandbox.sp.sp_sum(x, axis=1)
assert z.type.broadcastable == (False,) assert z.type.broadcastable == (False,)
f = theano.function([x], z) f = theano.function([x], z)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论