提交 5eff2ddf authored 作者: onze's avatar onze

CCW#37: minor changes to comply with pep8 and docstrings.

上级 13a24f7f
...@@ -240,7 +240,7 @@ class Remove0(Op): ...@@ -240,7 +240,7 @@ class Remove0(Op):
if self.inplace: if self.inplace:
l.append('inplace') l.append('inplace')
return self.__class__.__name__+'{%s}'%', '.join(l) return self.__class__.__name__+'{%s}'%', '.join(l)
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
......
...@@ -438,11 +438,12 @@ def test_remove0(): ...@@ -438,11 +438,12 @@ def test_remove0():
print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals() print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals()
# real # real
origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX) origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX)
with0 = matrix_class(origin).astype(theano.config.floatX) mat = matrix_class(origin).astype(theano.config.floatX)
with0[0,1] = with0[1,0] = with0[2,2] = 0 mat[0,1] = mat[1,0] = mat[2,2] = 0
assert with0.size == 9
assert mat.size == 9
# symbolic # symbolic
x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)() x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)()
# the In thingy has to be there because theano has as rule not to optimize inputs # the In thingy has to be there because theano has as rule not to optimize inputs
...@@ -454,12 +455,12 @@ def test_remove0(): ...@@ -454,12 +455,12 @@ def test_remove0():
v = [True for node in nodes if isinstance(node.op, sp.Remove0) and node.op.inplace] v = [True for node in nodes if isinstance(node.op, sp.Remove0) and node.op.inplace]
if v: if v:
assert any(v) assert any(v)
# checking # checking
# makes sense to change its name # makes sense to change its name
target = with0 target = mat
result = f(with0) result = f(mat)
with0.eliminate_zeros() mat.eliminate_zeros()
assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?' assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?'
def test_diagonal(): def test_diagonal():
...@@ -567,6 +568,9 @@ def test_col_scale(): ...@@ -567,6 +568,9 @@ def test_col_scale():
print >> sys.stderr, "WARNING: skipping gradient test because verify_grad doesn't support sparse arguments" print >> sys.stderr, "WARNING: skipping gradient test because verify_grad doesn't support sparse arguments"
if __name__ == '__main__': if __name__ == '__main__':
if 1:
test_remove0()
exit()
if 1: if 1:
testcase = TestSP testcase = TestSP
suite = unittest.TestLoader() suite = unittest.TestLoader()
......
...@@ -731,7 +731,7 @@ class ShapeFeature(object): ...@@ -731,7 +731,7 @@ class ShapeFeature(object):
def default_infer_shape(self, node, i_shapes): def default_infer_shape(self, node, i_shapes):
"""Return a list of shape tuple or None for the outputs of node. """Return a list of shape tuple or None for the outputs of node.
This function is used for Ops that don't implement infer_shape. This function is used for Ops that don't implement infer_shape.
Ops that do implement infer_shape should use the i_shapes parameter, Ops that do implement infer_shape should use the i_shapes parameter,
but this default implementation ignores it. but this default implementation ignores it.
...@@ -746,7 +746,7 @@ class ShapeFeature(object): ...@@ -746,7 +746,7 @@ class ShapeFeature(object):
def unpack(self, s_i): def unpack(self, s_i):
"""Return a symbolic integer scalar for the shape element s_i. """Return a symbolic integer scalar for the shape element s_i.
The s_i argument was produced by the infer_shape() of an Op subclass. The s_i argument was produced by the infer_shape() of an Op subclass.
""" """
# unpack the s_i that the Op returned # unpack the s_i that the Op returned
...@@ -777,7 +777,7 @@ class ShapeFeature(object): ...@@ -777,7 +777,7 @@ class ShapeFeature(object):
def set_shape(self, r, s): def set_shape(self, r, s):
"""Assign the shape `s` to previously un-shaped variable `r`. """Assign the shape `s` to previously un-shaped variable `r`.
:type r: a variable :type r: a variable
:type s: None or a tuple of symbolic integers :type s: None or a tuple of symbolic integers
""" """
...@@ -1951,6 +1951,7 @@ compile.optdb.register('local_inplace_incsubtensor1', ...@@ -1951,6 +1951,7 @@ compile.optdb.register('local_inplace_incsubtensor1',
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_remove0(node): def local_inplace_remove0(node):
""" """
Optimization to insert inplace versions of Remove0.
""" """
if isinstance(node.op, theano.sparse.sandbox.sp.Remove0) and not node.op.inplace: if isinstance(node.op, theano.sparse.sandbox.sp.Remove0) and not node.op.inplace:
new_op = node.op.__class__(inplace=True) new_op = node.op.__class__(inplace=True)
...@@ -1960,7 +1961,7 @@ def local_inplace_remove0(node): ...@@ -1960,7 +1961,7 @@ def local_inplace_remove0(node):
compile.optdb.register('local_inplace_remove0', compile.optdb.register('local_inplace_remove0',
TopoOptimizer(local_inplace_remove0, TopoOptimizer(local_inplace_remove0,
failure_callback=TopoOptimizer.warn_inplace), 60, failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG 'fast_run', 'inplace')
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论