提交 2553b8fd authored 作者: Frederic Bastien's avatar Frederic Bastien

Add opt local_useless_composite that remove Composite outputs that aren't needed.

上级 b4f4a23b
...@@ -7335,6 +7335,26 @@ else: ...@@ -7335,6 +7335,26 @@ else:
'FusionOptimizer') 'FusionOptimizer')
@register_canonicalize
@gof.local_optimizer([Elemwise])
def local_useless_composite(node):
"""For elemwise Composite that have multiple outputs, remove the
outputs that are not used.
"""
if (not isinstance(node.op, Elemwise) or
not isinstance(node.op.scalar_op, scalar.Composite)):
return
comp = node.op.scalar_op
idx = [i for i, o_extern in enumerate(node.outputs)
if o_extern.clients]
if len(idx) < len(node.outputs):
new_outputs = [comp.outputs[i] for i in idx]
c = scalar.Composite(inputs=comp.inputs,
outputs=new_outputs)
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
return dict(zip([node.outputs[i] for i in idx], e))
# ############################ # ############################
# # Remove consider_constant # # # Remove consider_constant #
# ############################ # ############################
......
...@@ -1526,6 +1526,26 @@ class TestCompositeCodegen(unittest.TestCase): ...@@ -1526,6 +1526,26 @@ class TestCompositeCodegen(unittest.TestCase):
fval = numpy.asarray(f([1, 2, 3])) fval = numpy.asarray(f([1, 2, 3]))
assert numpy.all(fval == [6, 12, 18]), fval assert numpy.all(fval == [6, 12, 18]), fval
def test_local_useless_composite(self):
x = theano.scalar.float32()
c = theano.scalar.Composite([x], [x+1, x-1])
X = theano.tensor.matrix()
o = theano.tensor.Elemwise(scalar_op=c)(X)
mode = theano.compile.mode.get_default_mode().including(
'local_useless_composite')
f = theano.function([X], o[0], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.]]), [[2.]])
f = theano.function([X], o[1], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.]]), [[0.]])
def test_log1p(): def test_log1p():
m = theano.config.mode m = theano.config.mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论