提交 6af49652 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update verbose optimizer format and print individual steps in LocalOptGroup

上级 3e00b5cd
...@@ -596,7 +596,9 @@ class ReplaceValidate(History, Validator): ...@@ -596,7 +596,9 @@ class ReplaceValidate(History, Validator):
except Exception as e: except Exception as e:
fgraph.revert(chk) fgraph.revert(chk)
if verbose: if verbose:
print(f"validate failed on node {r}.\n Reason: {reason}, {e}") print(
f"optimizer: validate failed on node {r}.\n Reason: {reason}, {e}"
)
raise raise
if config.scan__debug: if config.scan__debug:
from aesara.scan.op import Scan from aesara.scan.op import Scan
...@@ -618,7 +620,7 @@ class ReplaceValidate(History, Validator): ...@@ -618,7 +620,7 @@ class ReplaceValidate(History, Validator):
"Scan removed", nb, nb2, getattr(reason, "name", reason), r, new_r "Scan removed", nb, nb2, getattr(reason, "name", reason), r, new_r
) )
if verbose: if verbose:
print(reason, r, new_r) print(f"optimizer: rewrite {reason} replaces {r} with {new_r}")
# The return is needed by replace_all_validate_remove # The return is needed by replace_all_validate_remove
return chk return chk
......
...@@ -510,7 +510,7 @@ class FunctionGraph(MetaObject): ...@@ -510,7 +510,7 @@ class FunctionGraph(MetaObject):
if verbose is None: if verbose is None:
verbose = config.optimizer_verbose verbose = config.optimizer_verbose
if verbose: if verbose:
print(reason, var, new_var) print(f"optimizer: rewrite {reason} replaces {var} with {new_var}")
new_var = var.type.filter_variable(new_var, allow_convert=True) new_var = var.type.filter_variable(new_var, allow_convert=True)
......
...@@ -1347,6 +1347,10 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1347,6 +1347,10 @@ class LocalOptGroup(LocalOptimizer):
new_vars = new_repl new_vars = new_repl
else: # It must be a dict else: # It must be a dict
new_vars = list(new_repl.values()) new_vars = list(new_repl.values())
if config.optimizer_verbose:
print(f"optimizer: rewrite {opt} replaces {node} with {new_repl}")
if self.profile: if self.profile:
self.node_created[opt] += len( self.node_created[opt] += len(
list(applys_between(fgraph.variables, new_vars)) list(applys_between(fgraph.variables, new_vars))
......
import pytest
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.features import NodeFinder from aesara.graph.features import Feature, NodeFinder, ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from tests.graph.utils import MyVariable, op1
class TestNodeFinder: class TestNodeFinder:
...@@ -82,3 +85,35 @@ class TestNodeFinder: ...@@ -82,3 +85,35 @@ class TestNodeFinder:
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([t for t in g.get_nodes(type)]) == num: if not len([t for t in g.get_nodes(type)]) == num:
raise Exception("Expected: %i times %s" % (num, type)) raise Exception("Expected: %i times %s" % (num, type))
class TestReplaceValidate:
def test_verbose(self, capsys):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
fg = FunctionGraph([var1, var2], [var3], clone=False)
rv_feature = ReplaceValidate()
fg.attach_feature(rv_feature)
rv_feature.replace_all_validate(
fg, [(var3, var1)], reason="test-reason", verbose=True
)
capres = capsys.readouterr()
assert capres.err == ""
assert "optimizer: rewrite test-reason replaces Op1.0 with var1" in capres.out
class TestFeature(Feature):
def validate(self, *args):
raise Exception()
fg.attach_feature(TestFeature())
with pytest.raises(Exception):
rv_feature.replace_all_validate(
fg, [(var3, var1)], reason="test-reason", verbose=True
)
capres = capsys.readouterr()
assert "optimizer: validate failed on node Op1.0" in capres.out
...@@ -256,6 +256,19 @@ class TestFunctionGraph: ...@@ -256,6 +256,19 @@ class TestFunctionGraph:
assert fg.apply_nodes == {var4.owner, var5.owner} assert fg.apply_nodes == {var4.owner, var5.owner}
assert var4.owner.inputs == [var1, var2] assert var4.owner.inputs == [var1, var2]
def test_replace_verbose(self, capsys):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
fg = FunctionGraph([var1, var2], [var3], clone=False)
fg.replace(var3, var1, reason="test-reason", verbose=True)
capres = capsys.readouterr()
assert capres.err == ""
assert "optimizer: rewrite test-reason replaces Op1.0 with var1" in capres.out
def test_replace_circular(self): def test_replace_circular(self):
"""`FunctionGraph` allows cycles--for better or worse.""" """`FunctionGraph` allows cycles--for better or worse."""
......
...@@ -7,12 +7,14 @@ from aesara.graph.fg import FunctionGraph ...@@ -7,12 +7,14 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
LocalOptGroup,
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpSub, OpSub,
PatternSub, PatternSub,
TopoOptimizer, TopoOptimizer,
aesara, aesara,
local_optimizer,
logging, logging,
pre_constant_merge, pre_constant_merge,
pre_greedy_local_optimizer, pre_greedy_local_optimizer,
...@@ -698,3 +700,42 @@ def test_patternsub_invalid_dtype(out_pattern): ...@@ -698,3 +700,42 @@ def test_patternsub_invalid_dtype(out_pattern):
) )
opt.optimize(fg) opt.optimize(fg)
assert fg.apply_nodes.pop().op == op_cast_type2 assert fg.apply_nodes.pop().op == op_cast_type2
class TestLocalOptGroup:
def test_optimizer_verbose(self, capsys):
x = MyVariable("x")
y = MyVariable("y")
o1 = op1(x, y)
fgraph = FunctionGraph([x, y], [o1], clone=False)
@local_optimizer(None)
def local_opt_1(fgraph, node):
if node.inputs[0] == x:
res = op2(y, *node.inputs[1:])
return [res]
@local_optimizer(None)
def local_opt_2(fgraph, node):
if node.inputs[0] == y:
res = op2(x, *node.inputs[1:])
return [res]
opt_group = LocalOptGroup(local_opt_1, local_opt_2)
with config.change_flags(optimizer_verbose=True):
(new_res,) = opt_group.transform(fgraph, o1.owner)
_ = opt_group.transform(fgraph, new_res.owner)
capres = capsys.readouterr()
assert capres.err == ""
assert (
"optimizer: rewrite local_opt_1 replaces Op1(x, y) with [Op2.0]"
in capres.out
)
assert (
"optimizer: rewrite local_opt_2 replaces Op2(y, y) with [Op2.0]"
in capres.out
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论