提交 5f8cee6b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test more FusionOptimizer graphs

上级 96122d15
......@@ -569,8 +569,6 @@ class FusionOptimizer(GraphRewriter):
return scalar_inputs, scalar_outputs
def apply(self, fgraph):
nb_replacement = 0
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy()
......@@ -925,6 +923,8 @@ class FusionOptimizer(GraphRewriter):
starting_nodes=starting_nodes,
)
nb_fused = 0
nb_replacement = 0
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
if (len(inputs) + len(outputs)) > max_operands:
warn(
......@@ -943,11 +943,13 @@ class FusionOptimizer(GraphRewriter):
if old_out.name:
composite_out.name = old_out.name
starting_nodes = len(fgraph.apply_nodes)
fgraph.replace_all_validate(
list(zip(outputs, composite_outputs, strict=True)),
reason=self.__class__.__name__,
)
nb_replacement += 1
nb_fused += 1
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
......@@ -965,7 +967,7 @@ class FusionOptimizer(GraphRewriter):
return (
self,
1, # nb_iter
nb_fused,
nb_replacement,
0, # nb_inconsintency_replace
validate_time,
......@@ -978,7 +980,7 @@ class FusionOptimizer(GraphRewriter):
def print_profile(stream, prof, level=0):
blanc = " " * level
print(blanc, "FusionOptimizer", file=stream)
print(blanc, " nb_iter", prof[1], file=stream)
print(blanc, " nb_fused", prof[1], file=stream)
print(blanc, " nb_replacement", prof[2], file=stream)
print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
print(blanc, " validate_time", prof[4], file=stream)
......
......@@ -273,7 +273,8 @@ class TestFusion:
fwx = fw + fx
ftanx = tan(fx)
def large_fuseable_graph(self, n):
@staticmethod
def large_fuseable_graph(n):
factors = []
sd = dscalar()
means = dvector()
......@@ -296,6 +297,28 @@ class TestFusion:
dlogp = [pytensor.grad(logp, v) for v in vars]
return vars, dlogp
@staticmethod
def deep_small_kernels(n):
x = pt.matrix("x")
out = x
for _ in range(n):
out = pt.sin(out.T) + pt.cos(out)
return [x], [out]
@staticmethod
def test_diamond_graph():
a = pt.matrix("a")
b = pt.exp(a)
c = pt.log(b)
d = pt.sin(c)
e = c + d
fg = FunctionGraph([a], [e], clone=False)
_, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg)
assert nb_fused == 1
assert nb_replacement == 4
@pytest.mark.parametrize(
"case",
[
......@@ -1347,16 +1370,26 @@ class TestFusion:
benchmark(func)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_rewrite_benchmark(self, benchmark):
inps, outs = self.large_fuseable_graph(n=25)
@pytest.mark.parametrize(
"graph_fn, n, expected_n_repl",
[
("deep_small_kernels", 20, (20, 60)),
("large_fuseable_graph", 25, (103, 876)),
],
)
def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark):
inps, outs = getattr(self, graph_fn)(n)
fg = FunctionGraph(inps, outs)
opt = FusionOptimizer()
def rewrite_func():
nb_replacement = opt.apply(fg.clone())[2]
return nb_replacement
fg_clone = fg.clone()
_, nb_fused, nb_replacement, *_ = opt.apply(fg_clone)
# fg_clone.dprint()
return nb_fused, nb_replacement
assert benchmark(rewrite_func) == 103
assert rewrite_func() == expected_n_repl
benchmark.pedantic(rewrite_func, rounds=7, iterations=5)
def test_no_warning_from_old_client(self):
# There used to be a warning issued when creating fuseable mapping
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论