提交 83356200 authored 作者: Valentin Bisson's avatar Valentin Bisson

Merge test_sp.py.

...@@ -438,10 +438,11 @@ def test_remove0(): ...@@ -438,10 +438,11 @@ def test_remove0():
print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals() print 'config: format=\'%(format)s\', matrix_class=%(matrix_class)s'%locals()
# real # real
origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX) origin = (numpy.arange(9) + 1).reshape((3, 3)).astype(theano.config.floatX)
with0 = matrix_class(origin).astype(theano.config.floatX) mat = matrix_class(origin).astype(theano.config.floatX)
with0[0,1] = with0[1,0] = with0[2,2] = 0 mat[0,1] = mat[1,0] = mat[2,2] = 0
assert with0.size == 9
assert mat.size == 9
# symbolic # symbolic
x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)() x = theano.sparse.SparseType(format=format, dtype=theano.config.floatX)()
...@@ -456,9 +457,9 @@ def test_remove0(): ...@@ -456,9 +457,9 @@ def test_remove0():
# checking # checking
# makes sense to change its name # makes sense to change its name
target = with0 target = mat
result = f(with0) result = f(mat)
with0.eliminate_zeros() mat.eliminate_zeros()
assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?' assert result.size == target.size, 'Matrices sizes differ. Have zeros been removed ?'
def test_diagonal(): def test_diagonal():
......
...@@ -1951,6 +1951,7 @@ compile.optdb.register('local_inplace_incsubtensor1', ...@@ -1951,6 +1951,7 @@ compile.optdb.register('local_inplace_incsubtensor1',
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_remove0(node): def local_inplace_remove0(node):
""" """
Optimization to insert inplace versions of Remove0.
""" """
if isinstance(node.op, theano.sparse.sandbox.sp.Remove0) and not node.op.inplace: if isinstance(node.op, theano.sparse.sandbox.sp.Remove0) and not node.op.inplace:
new_op = node.op.__class__(inplace=True) new_op = node.op.__class__(inplace=True)
...@@ -1960,7 +1961,7 @@ def local_inplace_remove0(node): ...@@ -1960,7 +1961,7 @@ def local_inplace_remove0(node):
compile.optdb.register('local_inplace_remove0', compile.optdb.register('local_inplace_remove0',
TopoOptimizer(local_inplace_remove0, TopoOptimizer(local_inplace_remove0,
failure_callback=TopoOptimizer.warn_inplace), 60, failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG 'fast_run', 'inplace')
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论