提交 74e88eb9 authored 作者: Valentin Bisson's avatar Valentin Bisson

CCW#37: Fixed checking that inplace optimization is applied when possible, and…

CCW#37: Fixed checking that inplace optimization is applied when possible, and updated Remove0 sparse Op's __str__.
上级 81813e95
...@@ -236,10 +236,10 @@ class Remove0(Op): ...@@ -236,10 +236,10 @@ class Remove0(Op):
return 64153 ^ hash(type(self)) ^ hash(self.inplace) return 64153 ^ hash(type(self)) ^ hash(self.inplace)
def __str__(self): def __str__(self):
s = self.__class__.__name__ l = []
if self.inplace: if self.inplace:
s += '{inplace}' l.append('inplace')
return s return self.__class__.__name__+'{%s}'%', '.join(l)
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
...@@ -251,7 +251,6 @@ class Remove0(Op): ...@@ -251,7 +251,6 @@ class Remove0(Op):
c = x.copy() c = x.copy()
c.eliminate_zeros() c.eliminate_zeros()
z[0] = c z[0] = c
return
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return [gz] return [gz]
......
...@@ -434,11 +434,10 @@ def test_remove0(): ...@@ -434,11 +434,10 @@ def test_remove0():
('csc',scipy.sparse.csc_matrix), ('csc',scipy.sparse.csc_matrix),
('csr',scipy.sparse.csr_matrix), ('csr',scipy.sparse.csr_matrix),
] ]
for format,matrix_class in configs: for format,matrix_class in configs:
print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals() print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals()
# real # real
origin = (numpy.arange(9)+1).reshape((3,3)).astype(theano.config.floatX) origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX)
with0 = matrix_class(origin).astype(theano.config.floatX) with0 = matrix_class(origin).astype(theano.config.floatX)
with0[0,1] = with0[1,0] = with0[2,2] = 0 with0[0,1] = with0[1,0] = with0[2,2] = 0
...@@ -446,7 +445,15 @@ def test_remove0(): ...@@ -446,7 +445,15 @@ def test_remove0():
# symbolic # symbolic
x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)() x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)()
f = theano.function([x], sp.Remove0()(x)) # the In thingy has to be there because theano has as rule to optimize inputs
f = theano.function([theano.In(x, borrow=True, mutable=True)], sp.Remove0()(x))
# assert optimization is applied
# list of apply nodes in the optimized graph.
nodes = f.maker.env.toposort()
v = [True for node in nodes if isinstance(node.op, sp.Remove0) and node.op.inplace]
if v:
assert any(v)
# checking # checking
# makes sense to change its name # makes sense to change its name
...@@ -455,7 +462,6 @@ def test_remove0(): ...@@ -455,7 +462,6 @@ def test_remove0():
with0.eliminate_zeros() with0.eliminate_zeros()
assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?' assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?'
def test_diagonal(): def test_diagonal():
for K in 1, 5: for K in 1, 5:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论