Commit 746eecbc authored by Brandon T. Willard, committed by Brandon T. Willard

Rename optimize_graph to rewrite_graph

Parent 2dc0af2f
......@@ -14,7 +14,7 @@ from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import node_rewriter, graph_rewriter
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery
# isort: on
import copy
-from typing import Generator, Sequence, Union, cast
+import warnings
+from typing import TYPE_CHECKING, Generator, Optional, Sequence, Union, cast
import aesara
from aesara.graph.basic import (
......@@ -13,46 +14,72 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import RewriteDatabaseQuery
-def optimize_graph(
-    fgraph: Union[Variable, FunctionGraph],
-    include: Sequence[str] = ["canonicalize"],
-    custom_opt=None,
+if TYPE_CHECKING:
+    from aesara.graph.opt import GraphRewriter
+
+
+def rewrite_graph(
+    graph: Union[Variable, Sequence[Variable], FunctionGraph],
+    include: Sequence[str] = ("canonicalize",),
+    custom_rewrite: Optional["GraphRewriter"] = None,
    clone: bool = False,
-    **kwargs
-) -> Union[Variable, FunctionGraph]:
-    """Easily optimize a graph.
+    custom_opt: Optional["GraphRewriter"] = None,
+    **kwargs,
+) -> Union[Variable, Sequence[Variable], FunctionGraph]:
+    """Easily apply rewrites to a graph.

    Parameters
-    ==========
-    fgraph:
-        A ``FunctionGraph`` or ``Variable`` to be optimized.
-    include:
-        String names of the optimizations to be applied. The default
-        optimization is ``"canonicalization"``.
-    custom_opt:
-        A custom ``Optimization`` to also be applied.
-    clone:
-        Whether or not to clone the input graph before optimizing.
-    **kwargs:
-        Keyword arguments passed to the ``aesara.graph.optdb.RewriteDatabaseQuery`` object.
+    ----------
+    graph
+        A `FunctionGraph` or `Variable` to be rewritten.
+    include
+        String names of the rewrites to be queried, via a
+        `RewriteDatabaseQuery` instance, and applied. The default rewrite
+        query string is ``"canonicalization"``.
+    custom_rewrite
+        A custom `Rewriter` to also be applied.
+    clone
+        Whether or not to clone the input graph before rewriting.
+    **kwargs
+        Keyword arguments passed to a `RewriteDatabaseQuery` object.

    """
    from aesara.compile import optdb

-    return_only_out = False
-    if not isinstance(fgraph, FunctionGraph):
-        fgraph = FunctionGraph(outputs=[fgraph], clone=clone)
-        return_only_out = True
+    return_fgraph = False
+    if isinstance(graph, FunctionGraph):
+        outputs: Sequence[Variable] = graph.outputs
+        fgraph = graph
+        return_fgraph = True
+    else:
+        if isinstance(graph, (list, tuple)):
+            outputs = graph
+        else:
+            assert isinstance(graph, Variable)
+            outputs = [graph]
+
+        fgraph = FunctionGraph(outputs=outputs, clone=clone)

-    canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
-    _ = canonicalize_opt.rewrite(fgraph)
+    query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
+    _ = query_rewrites.rewrite(fgraph)

-    if custom_opt:
-        custom_opt.rewrite(fgraph)
+    if custom_opt is not None:
+        warnings.warn(
+            "`custom_opt` is deprecated; use `custom_rewrite` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        custom_rewrite = custom_opt
+
+    if custom_rewrite:
+        custom_rewrite.rewrite(fgraph)

-    if return_only_out:
-        return fgraph.outputs[0]
-    else:
-        return fgraph
+    if return_fgraph:
+        return fgraph
+    else:
+        if isinstance(graph, (list, tuple)):
+            return fgraph.outputs
+        else:
+            return fgraph.outputs[0]
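For reference, a minimal sketch of the new calling convention, modeled directly on `test_rewrite_graph` further down in this diff (a bare `Variable` in yields a rewritten `Variable` out; passing a `FunctionGraph` returns the rewritten `FunctionGraph` instead):

import aesara
from aesara.graph.opt import graph_rewriter
from aesara.graph.opt_utils import rewrite_graph
from aesara.tensor.type import vectors

x, y = vectors("xy")

@graph_rewriter
def custom_rewrite(fgraph):
    # Replace `x` with `y` wherever it appears in the graph
    fgraph.replace(x, y, import_missing=True)

x_rewritten = rewrite_graph(x, custom_rewrite=custom_rewrite)
assert x_rewritten is y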
def is_same_graph_with_merge(var1, var2, givens=None):
......@@ -81,7 +108,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
# Perform merge optimization.
MergeOptimizer().rewrite(fgraph)
# When two variables perform the same computations, they will have the same
-# owner in the optimized graph.
+# owner in the rewritten graph.
# We need to be careful with the special case where the owner is None,
# which happens when the graph is made of a single Variable.
# We also need to make sure we replace a Variable if it is present in
......@@ -221,3 +248,28 @@ def get_clients_at_depth(
else:
assert var.owner is not None
yield var.owner
+DEPRECATED_NAMES = [
+    (
+        "optimize_graph",
+        "`optimize_graph` is deprecated: use `rewrite_graph` instead.",
+        rewrite_graph,
+    ),
+]
+
+
+def __getattr__(name):
+    """Intercept module-level attribute access of deprecated symbols.
+
+    Adapted from https://stackoverflow.com/a/55139609/3006474.
+
+    """
+    from warnings import warn
+
+    for old_name, msg, old_object in DEPRECATED_NAMES:
+        if name == old_name:
+            warn(msg, DeprecationWarning, stacklevel=2)
+            return old_object
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
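The shim above relies on module-level `__getattr__` (PEP 562), which also fires for `from`-imports of names that are no longer real module attributes. A minimal sketch of what a caller would observe:

import warnings

from aesara.graph.opt_utils import rewrite_graph

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Routed through `__getattr__` above, since the module no longer
    # defines `optimize_graph` directly
    from aesara.graph.opt_utils import optimize_graph

assert optimize_graph is rewrite_graph  # the old name aliases the new function
assert any(issubclass(w.category, DeprecationWarning) for w in caught)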
......@@ -24,7 +24,7 @@ from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefin
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
......@@ -1336,7 +1336,7 @@ def infer_broadcastable(shape):
features=[ShapeFeature()],
clone=True,
)
-folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
+folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
return sh, bcast
......
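The `infer_broadcastable` change above keeps the same constant-folding pattern under the new name. A small sketch of that pattern in isolation (the `topo_constant_folding` import path is an assumption for this revision):

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import rewrite_graph
from aesara.tensor.basic_opt import topo_constant_folding  # import path assumed

# Fold a tiny shape graph down to constants
shape = at.as_tensor_variable((1, 3))
shape_fg = FunctionGraph(outputs=[shape[0], shape[1]], clone=True)
folded = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
folded_vals = [getattr(s, "data", s) for s in folded]  # -> [1, 3]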
......@@ -1433,10 +1433,10 @@ class ShapeFeature(Feature):
clone=True,
# copy_inputs=False,
)
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph

-canon_shapes = optimize_graph(
-    shapes_fg, custom_opt=topo_constant_folding
+canon_shapes = rewrite_graph(
+    shapes_fg, custom_rewrite=topo_constant_folding
).outputs
sx = canon_shapes[: len(sx)]
......
......@@ -446,7 +446,7 @@ The following is an example that distributes dot products across additions.
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import EquilibriumGraphRewriter
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.tensor.math import _dot
from etuples import etuple
from kanren import conso, eq, fact, heado, tailo
......@@ -499,7 +499,7 @@ Below, we apply `dot_distribute_rewrite` to a few example graphs. First we crea
Next we apply the rewrite to the graph:
->>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_rewrite, clone=False)
+>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
>>> print(aesara.pprint(res))
((A @ x) + (A @ y))
......@@ -511,7 +511,7 @@ few more test cases:
>>> test_at = A_at.dot((x_at + y_at) + (z_at + w_at))
>>> print(aesara.pprint(test_at))
(A @ ((x + y) + (z + w)))
->>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_rewrite, clone=False)
+>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
>>> print(aesara.pprint(res))
(((A @ x) + (A @ y)) + ((A @ z) + (A @ w)))
......@@ -520,7 +520,7 @@ few more test cases:
>>> test_at = A_at.dot(x_at + (y_at + B_at.dot(z_at + w_at)))
>>> print(aesara.pprint(test_at))
(A @ (x + (y + ((B @ z) + (B @ w)))))
->>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_rewrite, clone=False)
+>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
>>> print(aesara.pprint(res))
((A @ x) + ((A @ y) + ((A @ (B @ z)) + (A @ (B @ w)))))
......@@ -533,7 +533,7 @@ To do that, we will create another :class:`Rewriter` that simply reverses the ar
to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``:
>>> dot_gather_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
->>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_rewrite, clone=False)
+>>> rev_res = rewrite_graph(res, include=[], custom_rewrite=dot_gather_rewrite, clone=False)
>>> print(aesara.pprint(rev_res))
(A @ (x + (y + (B @ (z + w)))))
......
......@@ -12,7 +12,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.basic import equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.utils import MissingInputError
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor
......@@ -455,7 +455,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
op_var = op_graph(x, y, z)
fg = FunctionGraph(outputs=[op_var[1]], clone=False)
-opt_res = optimize_graph(fg, custom_opt=ShapeOptimizer())
+opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer())
assert opt_res.shape_feature.shape_of[x] is None
assert opt_res.shape_feature.shape_of[z][0].data == 2
......
......@@ -14,7 +14,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.op import Op
from aesara.graph.opt import EquilibriumGraphRewriter
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.unify import eval_if_etuple
from aesara.tensor.math import Dot, _dot
from tests.graph.utils import MyType, MyVariable
......@@ -155,7 +155,7 @@ def test_KanrenRelationSub_dot():
[KanrenRelationSub(distributes)], max_use_ratio=10
)
-fgraph_opt = optimize_graph(fgraph, custom_opt=distribute_opt)
+fgraph_opt = rewrite_graph(fgraph, custom_rewrite=distribute_opt)
(expr_opt,) = fgraph_opt.outputs
assert expr_opt.owner.op == at.add
......
from aesara.graph.fg import FunctionGraph
-from aesara.graph.opt import optimizer
-from aesara.graph.opt_utils import is_same_graph, optimize_graph
+from aesara.graph.opt import graph_rewriter
+from aesara.graph.opt_utils import is_same_graph, rewrite_graph
from aesara.tensor.math import neg
from aesara.tensor.type import vectors
......@@ -139,20 +139,20 @@ class TestIsSameGraph:
)
-def test_optimize_graph():
+def test_rewrite_graph():
x, y = vectors("xy")
-    @optimizer
-    def custom_opt(fgraph):
+    @graph_rewriter
+    def custom_rewrite(fgraph):
fgraph.replace(x, y, import_missing=True)
-    x_opt = optimize_graph(x, custom_opt=custom_opt)
+    x_rewritten = rewrite_graph(x, custom_rewrite=custom_rewrite)

-    assert x_opt is y
+    assert x_rewritten is y

-    x_opt = optimize_graph(
-        FunctionGraph(outputs=[x], clone=False), custom_opt=custom_opt
+    x_rewritten = rewrite_graph(
+        FunctionGraph(outputs=[x], clone=False), custom_rewrite=custom_rewrite
    )

-    assert x_opt.outputs[0] is y
+    assert x_rewritten.outputs[0] is y
......@@ -17,7 +17,7 @@ from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, node_rewriter, out2in
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
......@@ -1817,7 +1817,7 @@ class TestUselessCheckAndRaise:
"""Remove `CheckAndRaise`s when all the conditions are always true."""
x = scalar()
fg = FunctionGraph(outputs=[assert_op(x, 1)], clone=False)
-fg_res = optimize_graph(fg, include=["canonicalize", "specialize"])
+fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
assert not any(isinstance(node.op, CheckAndRaise) for node in topo)
......@@ -1826,7 +1826,7 @@ class TestUselessCheckAndRaise:
x = scalar()
y = scalar()
fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
-fg_res = optimize_graph(fg, include=["canonicalize", "specialize"])
+fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
(assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)]
assert assert_node.inputs == [x, y]
......@@ -1836,7 +1836,7 @@ class TestUselessCheckAndRaise:
x = scalar()
y = scalar()
fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False)
-fg_res = optimize_graph(fg, include=["canonicalize", "specialize"])
+fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
(assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)]
assert assert_node.inputs[:2] == [x, y]
......@@ -3017,7 +3017,7 @@ def test_local_Shape_of_SpecifyShape(shape):
s = specify_shape(x, shape).shape
fgraph = FunctionGraph(outputs=[s], clone=False)
-_ = optimize_graph(fgraph, clone=False)
+_ = rewrite_graph(fgraph, clone=False)
assert x not in fgraph.variables
assert shape in fgraph.variables
......@@ -3034,7 +3034,7 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
fgraph = FunctionGraph(outputs=[s], clone=False)
assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
-_ = optimize_graph(fgraph, clone=False)
+_ = rewrite_graph(fgraph, clone=False)
assert x in fgraph.variables
assert s1 in fgraph.variables
......@@ -3046,7 +3046,7 @@ def test_local_Shape_i_of_broadcastable():
s = Shape_i(1)(x)
fgraph = FunctionGraph(outputs=[s], clone=False)
-_ = optimize_graph(fgraph, clone=False)
+_ = rewrite_graph(fgraph, clone=False)
assert x not in fgraph.variables
assert fgraph.outputs[0].data == 1
......@@ -3067,7 +3067,7 @@ def test_local_Shape_i_of_broadcastable():
x = MyVariable(MyType(), None, None)
s = Shape_i(0)(x)
fgraph = FunctionGraph(outputs=[s], clone=False)
-_ = optimize_graph(fgraph, clone=False)
+_ = rewrite_graph(fgraph, clone=False)
assert fgraph.outputs[0] == s
......@@ -3197,7 +3197,7 @@ def test_local_Unique_scalar(return_index, return_counts, return_inverse):
)
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
)
y_rewritten = y_rewritten_fg.outputs[0]
......@@ -3243,7 +3243,7 @@ def test_local_Unique_Alloc_lift(
# This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_Unique_Alloc_lift"],
......@@ -3301,7 +3301,7 @@ def test_local_Unique_BroadcastTo(
# This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_Unique_BroadcastTo_lift"],
......@@ -3364,7 +3364,7 @@ def test_local_Unique_Repeat(
# This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_Unique_Repeat_lift"],
......@@ -3420,7 +3420,7 @@ def test_local_Unique_second(
# This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_Unique_second_lift"],
......@@ -3466,7 +3466,7 @@ def test_local_merge_consecutive_specify_shape():
y = specify_shape(specify_shape(x, s), s)
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_merge_consecutive_specify_shape"],
......@@ -3483,7 +3483,7 @@ def test_local_merge_consecutive_specify_shape2():
y = specify_shape(specify_shape(x, [s1, s2, None]), [None, s3, s4])
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_merge_consecutive_specify_shape"],
......@@ -3507,7 +3507,7 @@ def test_local_remove_scalar_BroadcastTo():
assert isinstance(y.owner.op, BroadcastTo)
-res = optimize_graph(
+res = rewrite_graph(
y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
)
......@@ -3521,7 +3521,7 @@ def test_local_useless_dimshuffle_makevector():
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_useless_dimshuffle_makevector"],
......@@ -3544,7 +3544,7 @@ def test_Shape_i_canonicalize():
y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()])
-y_rewritten_fg = optimize_graph(
+y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=[
......@@ -3686,7 +3686,7 @@ class TestLocalElemwiseAlloc:
z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()])
-z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"])
+z_opt_fg = rewrite_graph(z_fg, clone=False, include=["local_elemwise_alloc"])
assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes)
def test_remove_alloc_wo_dimshuffle(self):
......
......@@ -25,7 +25,7 @@ from aesara.graph.opt import (
in2out,
out2in,
)
-from aesara.graph.opt_utils import is_same_graph, optimize_graph
+from aesara.graph.opt_utils import is_same_graph, rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
......@@ -251,7 +251,7 @@ class TestAlgebraicCanonizer:
],
)
def test_muldiv(self, e, exp_g):
-g_rewritten = optimize_graph(e, custom_opt=mul_canonizer)
+g_rewritten = rewrite_graph(e, custom_rewrite=mul_canonizer)
assert equal_computations([g_rewritten], [exp_g])
def test_elemwise_multiple_inputs_rewrites(self):
......@@ -966,8 +966,8 @@ class TestAlgebraicCanonizer:
z.owner.op, z.owner.inputs, [tensor("float64", (None, None))]
).outputs[0]
-z_rewritten = optimize_graph(
-    z, custom_opt=in2out(local_mul_canonizer, name="blah")
+z_rewritten = rewrite_graph(
+    z, custom_rewrite=in2out(local_mul_canonizer, name="blah")
)
# No rewrite was applied
assert z_rewritten is z
......@@ -4140,7 +4140,7 @@ def test_local_log_sum_exp_inf():
def test_local_reciprocal_1_plus_exp():
x = vector("x")
y = at.reciprocal(1 + exp(x))
-z = optimize_graph(y, include=["canonicalization", "stabilize", "specialize"])
+z = rewrite_graph(y, include=["canonicalization", "stabilize", "specialize"])
assert z.owner.op == sigmoid
......
......@@ -11,7 +11,7 @@ from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, ancestors
from aesara.graph.opt import check_stack_trace
-from aesara.graph.opt_utils import optimize_graph
+from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type
from aesara.raise_op import Assert
......@@ -1907,7 +1907,7 @@ def test_local_subtensor_shape_constant():
assert res.data == 1
# Make sure it's part of the canonicalizations
-res = optimize_graph(x)
+res = rewrite_graph(x)
assert isinstance(res, Constant)
assert res.data == 1
......@@ -2003,7 +2003,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val]))
# This optimization should appear in the canonicalizations
-y_opt = optimize_graph(y, clone=False)
+y_opt = rewrite_graph(y, clone=False)
if y.ndim == 0:
# SpecifyShape should be removed altogether
......@@ -2042,7 +2042,7 @@ def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
y = specify_shape(x, s)[idx]
# This optimization should appear in the canonicalizations
-y_opt = optimize_graph(y, clone=False)
+y_opt = rewrite_graph(y, clone=False)
assert not isinstance(y_opt.owner.op, SpecifyShape)
......