提交 5eff2ddf authored 作者: onze's avatar onze

CCW#37: minor changes to comply with pep8 and docstrings.

上级 13a24f7f
...@@ -438,10 +438,11 @@ def test_remove0(): ...@@ -438,10 +438,11 @@ def test_remove0():
print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals() print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals()
# real # real
origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX) origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX)
with0 = matrix_class(origin).astype(theano.config.floatX) mat = matrix_class(origin).astype(theano.config.floatX)
with0[0,1] = with0[1,0] = with0[2,2] = 0 mat[0,1] = mat[1,0] = mat[2,2] = 0
assert with0.size == 9
assert mat.size == 9
# symbolic # symbolic
x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)() x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)()
...@@ -457,9 +458,9 @@ def test_remove0(): ...@@ -457,9 +458,9 @@ def test_remove0():
# checking # checking
# makes sense to change its name # makes sense to change its name
target = with0 target = mat
result = f(with0) result = f(mat)
with0.eliminate_zeros() mat.eliminate_zeros()
assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?' assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?'
def test_diagonal(): def test_diagonal():
...@@ -567,6 +568,9 @@ def test_col_scale(): ...@@ -567,6 +568,9 @@ def test_col_scale():
print >> sys.stderr, "WARNING: skipping gradient test because verify_grad doesn't support sparse arguments" print >> sys.stderr, "WARNING: skipping gradient test because verify_grad doesn't support sparse arguments"
if __name__ == '__main__': if __name__ == '__main__':
if 1:
test_remove0()
exit()
if 1: if 1:
testcase = TestSP testcase = TestSP
suite = unittest.TestLoader() suite = unittest.TestLoader()
......
...@@ -1951,6 +1951,7 @@ compile.optdb.register('local_inplace_incsubtensor1', ...@@ -1951,6 +1951,7 @@ compile.optdb.register('local_inplace_incsubtensor1',
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_remove0(node): def local_inplace_remove0(node):
""" """
Optimization to insert inplace versions of Remove0.
""" """
if isinstance(node.op, theano.sparse.sandbox.sp.Remove0) and not node.op.inplace: if isinstance(node.op, theano.sparse.sandbox.sp.Remove0) and not node.op.inplace:
new_op = node.op.__class__(inplace=True) new_op = node.op.__class__(inplace=True)
...@@ -1960,7 +1961,7 @@ def local_inplace_remove0(node): ...@@ -1960,7 +1961,7 @@ def local_inplace_remove0(node):
compile.optdb.register('local_inplace_remove0', compile.optdb.register('local_inplace_remove0',
TopoOptimizer(local_inplace_remove0, TopoOptimizer(local_inplace_remove0,
failure_callback=TopoOptimizer.warn_inplace), 60, failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG 'fast_run', 'inplace')
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论