提交 d5cb23a5 作者: Ricardo Vieira 提交者: Ricardo Vieira

Exclude unnecessary inputs in useless_composite rewrite

上级 316cfd1d
...@@ -990,23 +990,33 @@ if config.tensor__local_elemwise_fusion: ...@@ -990,23 +990,33 @@ if config.tensor__local_elemwise_fusion:
@register_canonicalize @register_canonicalize
@register_specialize
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_useless_composite(fgraph, node): def local_useless_composite(fgraph, node):
"""For elemwise Composite that have multiple outputs, remove the """Remove inputs and outputs of Composite Ops that are not used anywhere."""
outputs that are not used.
"""
if not isinstance(node.op, Elemwise) or not isinstance( if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, aes.Composite node.op.scalar_op, aes.Composite
): ):
return return
comp = node.op.scalar_op comp = node.op.scalar_op
idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
if len(idx) < len(node.outputs): used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
new_outputs = [comp.outputs[i] for i in idx] comp_fgraph = FunctionGraph(
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs) inputs=comp.inputs, outputs=used_inner_outputs, clone=False
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True) )
return dict(zip([node.outputs[i] for i in idx], e)) used_inputs_idxs = [
i
for i, i_intern in enumerate(comp_fgraph.inputs)
if comp_fgraph.clients[i_intern]
]
used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs]
if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len(
node.outputs
):
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
c = aes.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e))
@node_rewriter([CAReduce]) @node_rewriter([CAReduce])
......
...@@ -1292,22 +1292,37 @@ class TestCompositeCodegen: ...@@ -1292,22 +1292,37 @@ class TestCompositeCodegen:
def test_local_useless_composite(self): def test_local_useless_composite(self):
x = aes.float32() x = aes.float32()
c = aes.Composite([x], [x + 1, x - 1]) y = aes.float32()
X = matrix() z = aes.float32()
o = Elemwise(scalar_op=c)(X) c = aes.Composite([x, y, z], [x + 1, y - 1])
X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite") mode = get_default_mode().including("local_useless_composite")
f = function([X], o[0], mode=mode) f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])
f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1 assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]]), [[2.0]]) utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
f = function([X], o[1], mode=mode) f = function([X, Y, Z], o2, mode=mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1 assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]]), [[0.0]]) utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
def test_local_useless_dimshuffle_makevector(): def test_local_useless_dimshuffle_makevector():
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论