提交 5f8cee6b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test more FusionOptimizer graphs

上级 96122d15
...@@ -569,8 +569,6 @@ class FusionOptimizer(GraphRewriter): ...@@ -569,8 +569,6 @@ class FusionOptimizer(GraphRewriter):
return scalar_inputs, scalar_outputs return scalar_inputs, scalar_outputs
def apply(self, fgraph): def apply(self, fgraph):
nb_replacement = 0
if fgraph.profile: if fgraph.profile:
validate_before = fgraph.profile.validate_time validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
...@@ -925,6 +923,8 @@ class FusionOptimizer(GraphRewriter): ...@@ -925,6 +923,8 @@ class FusionOptimizer(GraphRewriter):
starting_nodes=starting_nodes, starting_nodes=starting_nodes,
) )
nb_fused = 0
nb_replacement = 0
for inputs, outputs in find_next_fuseable_subgraph(fgraph): for inputs, outputs in find_next_fuseable_subgraph(fgraph):
if (len(inputs) + len(outputs)) > max_operands: if (len(inputs) + len(outputs)) > max_operands:
warn( warn(
...@@ -943,11 +943,13 @@ class FusionOptimizer(GraphRewriter): ...@@ -943,11 +943,13 @@ class FusionOptimizer(GraphRewriter):
if old_out.name: if old_out.name:
composite_out.name = old_out.name composite_out.name = old_out.name
starting_nodes = len(fgraph.apply_nodes)
fgraph.replace_all_validate( fgraph.replace_all_validate(
list(zip(outputs, composite_outputs, strict=True)), list(zip(outputs, composite_outputs, strict=True)),
reason=self.__class__.__name__, reason=self.__class__.__name__,
) )
nb_replacement += 1 nb_fused += 1
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
if fgraph.profile: if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before validate_time = fgraph.profile.validate_time - validate_before
...@@ -965,7 +967,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -965,7 +967,7 @@ class FusionOptimizer(GraphRewriter):
return ( return (
self, self,
1, # nb_iter nb_fused,
nb_replacement, nb_replacement,
0, # nb_inconsistency_replace 0, # nb_inconsistency_replace
validate_time, validate_time,
...@@ -978,7 +980,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -978,7 +980,7 @@ class FusionOptimizer(GraphRewriter):
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
blanc = " " * level blanc = " " * level
print(blanc, "FusionOptimizer", file=stream) print(blanc, "FusionOptimizer", file=stream)
print(blanc, " nb_iter", prof[1], file=stream) print(blanc, " nb_fused", prof[1], file=stream)
print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_replacement", prof[2], file=stream)
print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
print(blanc, " validate_time", prof[4], file=stream) print(blanc, " validate_time", prof[4], file=stream)
......
...@@ -273,7 +273,8 @@ class TestFusion: ...@@ -273,7 +273,8 @@ class TestFusion:
fwx = fw + fx fwx = fw + fx
ftanx = tan(fx) ftanx = tan(fx)
def large_fuseable_graph(self, n): @staticmethod
def large_fuseable_graph(n):
factors = [] factors = []
sd = dscalar() sd = dscalar()
means = dvector() means = dvector()
...@@ -296,6 +297,28 @@ class TestFusion: ...@@ -296,6 +297,28 @@ class TestFusion:
dlogp = [pytensor.grad(logp, v) for v in vars] dlogp = [pytensor.grad(logp, v) for v in vars]
return vars, dlogp return vars, dlogp
@staticmethod
def deep_small_kernels(n):
    """Build a chain of ``n`` stacked ``sin(out.T) + cos(out)`` stages.

    Starting from a single matrix input named ``"x"``, every stage replaces
    the running expression with the sum of the sine of its transpose and
    the cosine of itself.

    Parameters
    ----------
    n
        Number of stages to stack. With ``n == 0`` the loop body never runs
        and the output is the input variable itself.

    Returns
    -------
    tuple
        ``([x], [out])``: the list holding the input variable and the list
        holding the final expression.
    """
    root = pt.matrix("x")
    expr = root
    for _ in range(n):
        # Same construction order as a single combined expression:
        # transpose, sin, cos, then the addition.
        sin_of_transpose = pt.sin(expr.T)
        cos_term = pt.cos(expr)
        expr = sin_of_transpose + cos_term
    return [root], [expr]
@staticmethod
def test_diamond_graph():
    """Check the fusion counters reported for a small diamond-shaped graph.

    The graph is ``a -> exp -> log``; the ``log`` result is consumed twice,
    once by ``sin`` and once directly by the final addition, so the two
    paths split and rejoin.
    """
    a = pt.matrix("a")
    log_of_exp = pt.log(pt.exp(a))
    sin_branch = pt.sin(log_of_exp)
    total = log_of_exp + sin_branch

    graph = FunctionGraph([a], [total], clone=False)
    profile = FusionOptimizer().apply(graph)
    # Only the second and third entries of the returned profile are checked.
    _, nb_fused, nb_replacement, *_ = profile

    # One fused subgraph is expected.
    assert nb_fused == 1
    # NOTE(review): 4 is consistent with the four elementwise nodes
    # (exp, log, sin, add) collapsing into a single node — confirm against
    # how ``apply`` accumulates ``nb_replacement``.
    assert nb_replacement == 4
@pytest.mark.parametrize( @pytest.mark.parametrize(
"case", "case",
[ [
...@@ -1347,16 +1370,26 @@ class TestFusion: ...@@ -1347,16 +1370,26 @@ class TestFusion:
benchmark(func) benchmark(func)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler") @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_rewrite_benchmark(self, benchmark): @pytest.mark.parametrize(
inps, outs = self.large_fuseable_graph(n=25) "graph_fn, n, expected_n_repl",
[
("deep_small_kernels", 20, (20, 60)),
("large_fuseable_graph", 25, (103, 876)),
],
)
def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark):
inps, outs = getattr(self, graph_fn)(n)
fg = FunctionGraph(inps, outs) fg = FunctionGraph(inps, outs)
opt = FusionOptimizer() opt = FusionOptimizer()
def rewrite_func(): def rewrite_func():
nb_replacement = opt.apply(fg.clone())[2] fg_clone = fg.clone()
return nb_replacement _, nb_fused, nb_replacement, *_ = opt.apply(fg_clone)
# fg_clone.dprint()
return nb_fused, nb_replacement
assert benchmark(rewrite_func) == 103 assert rewrite_func() == expected_n_repl
benchmark.pedantic(rewrite_func, rounds=7, iterations=5)
def test_no_warning_from_old_client(self): def test_no_warning_from_old_client(self):
# There used to be a warning issued when creating fuseable mapping # There used to be a warning issued when creating fuseable mapping
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论