提交 62651c4a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add aesara.graph.opt_utils.optimize_graph function

上级 ea070b61
import copy import copy
from typing import Sequence, Union
import aesara import aesara
from aesara.graph.basic import equal_computations, graph_inputs, vars_between from aesara.graph.basic import Variable, equal_computations, graph_inputs, vars_between
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query
def optimize_graph(
fgraph: Union[Variable, FunctionGraph],
include: Sequence[str] = ["canonicalize"],
custom_opt=None,
clone: bool = False,
**kwargs
) -> Union[Variable, FunctionGraph]:
"""Easily optimize a graph.
Parameters
==========
fgraph:
A ``FunctionGraph`` or ``Variable`` to be optimized.
include:
String names of the optimizations to be applied. The default
optimization is ``"canonicalization"``.
custom_opt:
A custom ``Optimization`` to also be applied.
clone:
Whether or not to clone the input graph before optimizing.
**kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.Query`` object.
"""
from aesara.compile import optdb
return_only_out = False
if not isinstance(fgraph, FunctionGraph):
fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
return_only_out = True
canonicalize_opt = optdb.query(Query(include=include, **kwargs))
_ = canonicalize_opt.optimize(fgraph)
if custom_opt:
custom_opt.optimize(fgraph)
if return_only_out:
return fgraph.outputs[0]
else:
return fgraph
def is_same_graph_with_merge(var1, var2, givens=None): def is_same_graph_with_merge(var1, var2, givens=None):
......
from aesara.graph.opt_utils import is_same_graph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import optimizer
from aesara.graph.opt_utils import is_same_graph, optimize_graph
from aesara.tensor.math import neg from aesara.tensor.math import neg
from aesara.tensor.type import vectors from aesara.tensor.type import vectors
...@@ -135,3 +137,22 @@ class TestIsSameGraph: ...@@ -135,3 +137,22 @@ class TestIsSameGraph:
), ),
], ],
) )
def test_optimize_graph():
x, y = vectors("xy")
@optimizer
def custom_opt(fgraph):
fgraph.replace(x, y, import_missing=True)
x_opt = optimize_graph(x, custom_opt=custom_opt)
assert x_opt is y
x_opt = optimize_graph(
FunctionGraph(outputs=[x], clone=False), custom_opt=custom_opt
)
assert x_opt.outputs[0] is y
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论