提交 3b402c0c authored 作者: Frederic Bastien's avatar Frederic Bastien

added an optimization that remove useless join and test it.

上级 18d3ec48
......@@ -712,6 +712,13 @@ def local_track_shape_i(node):
replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Join])
def local_useless_join(node):
if isinstance(node.op, T.Join) and len(node.inputs)==2:
return [node.inputs[1]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Subtensor])
......
......@@ -21,7 +21,9 @@ from theano import pprint, shared
from theano.tests import unittest_tools as utt
from theano import function, compile
mode_opt = theano.config.mode
if mode_opt == 'FAST_COMPILE':
mode_opt = 'FAST_RUN'
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = TensorType(broadcastable = xbc, dtype = 'float64')('x')
......@@ -2092,6 +2094,45 @@ def test_make_vector():
except AssertionError:
pass
def test_local_useless_join():
#test for vector
a = TT.vector('a')
s = stack(a)
f = function([a], s, mode=mode_opt)
val = f([1])
assert numpy.all(val == [1])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
#test for matrix join(0,a)
a = TT.matrix('a')
s = join(0,a)
f = function([a], s, mode=mode_opt)
val = f([[1]])
assert numpy.all(val == [[1]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
#test for matrix join(1,a)
s = join(1,a)
f = function([a], s, mode=mode_opt)
val = f([[1]])
assert numpy.all(val == [[1]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
#test we don't apply when their is 2 inputs
s = join(1,a,a)
f = function([a], s, mode=mode_opt)
val = f([[1]])
assert numpy.all(val == [[1]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert f.maker.env.outputs[0].dtype == config.floatX
if __name__ == '__main__':
# unittest.main()
test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论