提交 e20dd0b6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove `tensor__local_elemwise_fusion` config.

The same behavior can be obtained with `optimizer_excluding`. The `local_careduce_fusion` rewrite is now included in the `elemwise_fusion` database; otherwise it would usually not be applied, because it ran before the fusion rewrites.
上级 671cb44b
...@@ -640,16 +640,6 @@ def add_tensor_configvars(): ...@@ -640,16 +640,6 @@ def add_tensor_configvars():
in_c_key=False, in_c_key=False,
) )
config.add(
"tensor__local_elemwise_fusion",
(
"Enable or not in fast_run mode(fast_run optimization) the elemwise "
"fusion optimization"
),
BoolParam(True),
in_c_key=False,
)
# http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx # http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx
config.add( config.add(
"lib__amblibm", "lib__amblibm",
......
...@@ -1085,38 +1085,10 @@ class FusionOptimizer(GraphRewriter): ...@@ -1085,38 +1085,10 @@ class FusionOptimizer(GraphRewriter):
print(blanc, " time_toposort", prof[7], file=stream) print(blanc, " time_toposort", prof[7], file=stream)
if config.tensor__local_elemwise_fusion:
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(),
"fast_run",
"fusion",
position=1,
)
compile.optdb.register(
"elemwise_fusion",
fuse_seqopt,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_useless_composite(fgraph, node): def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere.""" """Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not isinstance(node.op, Elemwise) or not isinstance( if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, aes.Composite node.op.scalar_op, aes.Composite
...@@ -1231,11 +1203,45 @@ def local_careduce_fusion(fgraph, node): ...@@ -1231,11 +1203,45 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)] return [new_car_op(*elm_inputs)]
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB()
compile.optdb.register( compile.optdb.register(
"elemwise_fusion",
fuse_seqopt,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(),
"fast_run",
"fusion",
position=1,
)
fuse_seqopt.register(
"local_useless_composite_outputs",
in2out(local_useless_composite_outputs),
"fast_run",
"fusion",
position=2,
)
fuse_seqopt.register(
"local_careduce_fusion", "local_careduce_fusion",
in2out(local_careduce_fusion), in2out(local_careduce_fusion),
"fast_run",
"fusion", "fusion",
position=49, position=10,
) )
......
...@@ -1425,39 +1425,40 @@ class TestCompositeCodegen: ...@@ -1425,39 +1425,40 @@ class TestCompositeCodegen:
fval = f([1, 2, 3]) fval = f([1, 2, 3])
assert np.all(fval == [6, 12, 18]) assert np.all(fval == [6, 12, 18])
def test_local_useless_composite(self):
x = aes.float32()
y = aes.float32()
z = aes.float32()
c = aes.Composite([x, y, z], [x + 1, y - 1])
X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite")
f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])
f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
f = function([X, Y, Z], o2, mode=mode) def test_local_useless_composite_outputs():
topo = f.maker.fgraph.toposort() x = aes.float32()
assert len(topo) == 1 y = aes.float32()
assert len(topo[0].inputs) == 1 z = aes.float32()
assert len(topo[0].outputs) == 1 c = aes.Composite([x, y, z], [x + 1, y - 1])
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite")
f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])
f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
f = function([X, Y, Z], o2, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
def test_local_useless_dimshuffle_makevector(): def test_local_useless_dimshuffle_makevector():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论