提交 3b402c0c authored 作者: Frederic Bastien's avatar Frederic Bastien

added an optimization that remove useless join and test it.

上级 18d3ec48
...@@ -712,6 +712,13 @@ def local_track_shape_i(node): ...@@ -712,6 +712,13 @@ def local_track_shape_i(node):
replacement = shape_feature.scheduled[node] replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Join])
def local_useless_join(node):
if isinstance(node.op, T.Join) and len(node.inputs)==2:
return [node.inputs[1]]
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Subtensor]) @gof.local_optimizer([T.Subtensor])
......
...@@ -21,7 +21,9 @@ from theano import pprint, shared ...@@ -21,7 +21,9 @@ from theano import pprint, shared
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import function, compile from theano import function, compile
mode_opt = theano.config.mode
if mode_opt == 'FAST_COMPILE':
mode_opt = 'FAST_RUN'
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = TensorType(broadcastable = xbc, dtype = 'float64')('x') x = TensorType(broadcastable = xbc, dtype = 'float64')('x')
...@@ -2092,6 +2094,45 @@ def test_make_vector(): ...@@ -2092,6 +2094,45 @@ def test_make_vector():
except AssertionError: except AssertionError:
pass pass
def test_local_useless_join():
#test for vector
a = TT.vector('a')
s = stack(a)
f = function([a], s, mode=mode_opt)
val = f([1])
assert numpy.all(val == [1])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
#test for matrix join(0,a)
a = TT.matrix('a')
s = join(0,a)
f = function([a], s, mode=mode_opt)
val = f([[1]])
assert numpy.all(val == [[1]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
#test for matrix join(1,a)
s = join(1,a)
f = function([a], s, mode=mode_opt)
val = f([[1]])
assert numpy.all(val == [[1]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
#test we don't apply when their is 2 inputs
s = join(1,a,a)
f = function([a], s, mode=mode_opt)
val = f([[1]])
assert numpy.all(val == [[1]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert f.maker.env.outputs[0].dtype == config.floatX
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论