提交 d5cb23a5 作者: Ricardo Vieira 提交者: Ricardo Vieira

Exclude unnecessary inputs in useless_composite rewrite

上级 316cfd1d
......@@ -990,23 +990,33 @@ if config.tensor__local_elemwise_fusion:
# NOTE(review): this listing is a diff hunk whose +/- markers and indentation
# were lost, so pre-change and post-change lines are interleaved below. The
# "presumably removed" / "presumably added" labels are inferred from the hunk
# header line counts and the variable names; confirm against the commit.
@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
# Presumably removed: the old docstring, which only mentions unused outputs.
"""For elemwise Composite that have multiple outputs, remove the
outputs that are not used.
"""
# Presumably added: the new docstring, which covers unused inputs as well.
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
# Guard: bail out unless the node is an Elemwise whose scalar_op is a Composite.
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, aes.Composite
):
return
comp = node.op.scalar_op
# Presumably removed (next six lines): the old body. It kept every Composite
# input and rebuilt the Op only when some node output had no clients.
idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
if len(idx) < len(node.outputs):
new_outputs = [comp.outputs[i] for i in idx]
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
return dict(zip([node.outputs[i] for i in idx], e))
# Presumably added (everything below): the new body.
# Positions of the node outputs whose entry in fgraph.clients is non-empty.
used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
# The Composite's inner (scalar) outputs at those same positions.
used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
# Build a graph over the inner inputs and only the used inner outputs
# (clone=False), so its clients mapping can be queried per inner input.
comp_fgraph = FunctionGraph(
inputs=comp.inputs, outputs=used_inner_outputs, clone=False
)
# Positions of the inner inputs that have a non-empty clients entry in that graph.
used_inputs_idxs = [
i
for i, i_intern in enumerate(comp_fgraph.inputs)
if comp_fgraph.clients[i_intern]
]
used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs]
# Rebuild only when at least one input or one output was dropped.
# NOTE(review): with the indentation lost, the four lines after the condition
# are presumably the `if` body, making the no-change path an implicit
# `return None` -- confirm against the original file.
if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len(
node.outputs
):
# Outer inputs at the same positions as the inner inputs that were kept.
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
c = aes.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
# Replacement mapping: each still-used original output -> the matching new output.
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e))
@node_rewriter([CAReduce])
......
......@@ -1292,22 +1292,37 @@ class TestCompositeCodegen:
# NOTE(review): this listing is a diff hunk whose +/- markers and indentation
# were lost, so pre-change and post-change lines are interleaved below. The
# "presumably removed" / "presumably added" labels are inferred from the hunk
# header line counts and the variable names; confirm against the commit.
def test_local_useless_composite(self):
x = aes.float32()
# Presumably removed (next three lines): old setup with a single input and
# a two-output Composite.
c = aes.Composite([x], [x + 1, x - 1])
X = matrix()
o = Elemwise(scalar_op=c)(X)
# Presumably added (next seven lines): new setup with three scalar inputs.
# z feeds neither output; each output depends on exactly one of x / y.
y = aes.float32()
z = aes.float32()
c = aes.Composite([x, y, z], [x + 1, y - 1])
X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite")
# Presumably removed: old compilation of the first output only.
f = function([X], o[0], mode=mode)
# Case 1 (presumably added): both outputs requested. The compiled graph is
# expected to be a single node with two inputs and two outputs.
f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
# NaN is passed for Z; the asserted results are the finite values derived
# from X and Y alone.
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])
# Case 2 (presumably added): only o1 requested. One input and one output expected.
f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
# Presumably removed: old single-argument call.
utt.assert_allclose(f([[1.0]]), [[2.0]])
# Presumably added: NaN passed for Y and Z; the expected value comes from X only.
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
# Presumably removed: old compilation of the second output only.
f = function([X], o[1], mode=mode)
# Case 3 (presumably added): only o2 requested. One input and one output expected.
f = function([X, Y, Z], o2, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
# Presumably removed: old single-argument call.
utt.assert_allclose(f([[1.0]]), [[0.0]])
# Presumably added: NaN passed for X and Z; the expected value comes from Y only.
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
def test_local_useless_dimshuffle_makevector():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论