提交 8188f49f authored 作者: Frederic Bastien's avatar Frederic Bastien

make Shape op support sparse type and make the shape optimizer handle them correctly.

上级 9bd5b59a
......@@ -441,6 +441,26 @@ class test_structureddot(unittest.TestCase):
self.failUnless(numpy.allclose(theano_result, scipy_result))
self.failIf(theano_time > overhead_rtol*scipy_time + overhead_tol)
def test_shape_i():
sparse_dtype = 'float32'
a = SparseType('csr', dtype=sparse_dtype)()
f = theano.function([a], a.shape[1], mode='FAST_RUN')
assert f(sp.csr_matrix(random_lil((100,10), sparse_dtype, 3)))==(10)
def test_shape():
sparse_dtype = 'float32'
a = SparseType('csr', dtype=sparse_dtype)()
f = theano.function([a], a.shape, mode='FAST_RUN')
assert numpy.all(f(sp.csr_matrix(random_lil((100,10), sparse_dtype, 3)))==(100,10))
if theano.config.mode!='FAST_COMPILE':
topo = f.maker.env.toposort()
assert len(topo)==3
assert isinstance(topo[0].op,tensor.opt.Shape_i)
assert isinstance(topo[1].op,tensor.opt.Shape_i)
assert isinstance(topo[2].op,tensor.opt.MakeVector)
import theano.tensor.tests.test_sharedvar
test_shared_options=theano.tensor.tests.test_sharedvar.makeSharedTester(
theano.sparse.shared, 'float64',
......
......@@ -465,7 +465,7 @@ class ShapeFeature(object):
"""
def shape_i(self, i):
def op_deco(r):
if r.type.broadcastable[i]:
if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one
else:
return Shape_i(i)(r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论