提交 746eecbc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename optimize_graph to rewrite_graph

上级 2dc0af2f
...@@ -14,7 +14,7 @@ from aesara.graph.op import Op ...@@ -14,7 +14,7 @@ from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import node_rewriter, graph_rewriter from aesara.graph.opt import node_rewriter, graph_rewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery from aesara.graph.optdb import RewriteDatabaseQuery
# isort: on # isort: on
import copy import copy
from typing import Generator, Sequence, Union, cast import warnings
from typing import TYPE_CHECKING, Generator, Optional, Sequence, Union, cast
import aesara import aesara
from aesara.graph.basic import ( from aesara.graph.basic import (
...@@ -13,46 +14,72 @@ from aesara.graph.fg import FunctionGraph ...@@ -13,46 +14,72 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import RewriteDatabaseQuery from aesara.graph.optdb import RewriteDatabaseQuery
def optimize_graph( if TYPE_CHECKING:
fgraph: Union[Variable, FunctionGraph], from aesara.graph.opt import GraphRewriter
include: Sequence[str] = ["canonicalize"],
custom_opt=None,
def rewrite_graph(
graph: Union[Variable, Sequence[Variable], FunctionGraph],
include: Sequence[str] = ("canonicalize",),
custom_rewrite: Optional["GraphRewriter"] = None,
clone: bool = False, clone: bool = False,
**kwargs custom_opt: Optional["GraphRewriter"] = None,
) -> Union[Variable, FunctionGraph]: **kwargs,
"""Easily optimize a graph. ) -> Union[Variable, Sequence[Variable], FunctionGraph]:
"""Easily apply rewrites to a graph.
Parameters Parameters
========== ----------
fgraph: graph
A ``FunctionGraph`` or ``Variable`` to be optimized. A `FunctionGraph` or `Variable` to be rewritten.
include: include
String names of the optimizations to be applied. The default String names of the rewrites to be queried, via a
optimization is ``"canonicalization"``. `RewriteDatabaseQuery` instance, and applied. The default rewrite
custom_opt: query string is ``"canonicalization"``.
A custom ``Optimization`` to also be applied. custom_rewrite
clone: A custom `Rewriter` to also be applied.
Whether or not to clone the input graph before optimizing. clone
**kwargs: Whether or not to clone the input graph before rewriting.
Keyword arguments passed to the ``aesara.graph.optdb.RewriteDatabaseQuery`` object. **kwargs
Keyword arguments passed to a `RewriteDatabaseQuery` object.
""" """
from aesara.compile import optdb from aesara.compile import optdb
return_only_out = False return_fgraph = False
if not isinstance(fgraph, FunctionGraph): if isinstance(graph, FunctionGraph):
fgraph = FunctionGraph(outputs=[fgraph], clone=clone) outputs: Sequence[Variable] = graph.outputs
return_only_out = True fgraph = graph
return_fgraph = True
else:
if isinstance(graph, (list, tuple)):
outputs = graph
else:
assert isinstance(graph, Variable)
outputs = [graph]
fgraph = FunctionGraph(outputs=outputs, clone=clone)
query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = query_rewrites.rewrite(fgraph)
canonicalize_opt = optdb.query(RewriteDatabaseQuery(include=include, **kwargs)) if custom_opt is not None:
_ = canonicalize_opt.rewrite(fgraph) warnings.warn(
"`custom_opt` is deprecated; use `custom_rewrite` instead.",
DeprecationWarning,
stacklevel=2,
)
custom_rewrite = custom_opt
if custom_opt: if custom_rewrite:
custom_opt.rewrite(fgraph) custom_rewrite.rewrite(fgraph)
if return_only_out: if return_fgraph:
return fgraph.outputs[0]
else:
return fgraph return fgraph
else:
if isinstance(graph, (list, tuple)):
return fgraph.outputs
else:
return fgraph.outputs[0]
def is_same_graph_with_merge(var1, var2, givens=None): def is_same_graph_with_merge(var1, var2, givens=None):
...@@ -81,7 +108,7 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -81,7 +108,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
# Perform merge optimization. # Perform merge optimization.
MergeOptimizer().rewrite(fgraph) MergeOptimizer().rewrite(fgraph)
# When two variables perform the same computations, they will have the same # When two variables perform the same computations, they will have the same
# owner in the optimized graph. # owner in the rewritten graph.
# We need to be careful with the special case where the owner is None, # We need to be careful with the special case where the owner is None,
# which happens when the graph is made of a single Variable. # which happens when the graph is made of a single Variable.
# We also need to make sure we replace a Variable if it is present in # We also need to make sure we replace a Variable if it is present in
...@@ -221,3 +248,28 @@ def get_clients_at_depth( ...@@ -221,3 +248,28 @@ def get_clients_at_depth(
else: else:
assert var.owner is not None assert var.owner is not None
yield var.owner yield var.owner
DEPRECATED_NAMES = [
(
"optimize_graph",
"`optimize_graph` is deprecated: use `rewrite_graph` instead.",
rewrite_graph,
),
]
def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from warnings import warn
for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object
raise AttributeError(f"module {__name__} has no attribute {name}")
...@@ -24,7 +24,7 @@ from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefin ...@@ -24,7 +24,7 @@ from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefin
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
...@@ -1336,7 +1336,7 @@ def infer_broadcastable(shape): ...@@ -1336,7 +1336,7 @@ def infer_broadcastable(shape):
features=[ShapeFeature()], features=[ShapeFeature()],
clone=True, clone=True,
) )
folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape) bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
return sh, bcast return sh, bcast
......
...@@ -1433,10 +1433,10 @@ class ShapeFeature(Feature): ...@@ -1433,10 +1433,10 @@ class ShapeFeature(Feature):
clone=True, clone=True,
# copy_inputs=False, # copy_inputs=False,
) )
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
canon_shapes = optimize_graph( canon_shapes = rewrite_graph(
shapes_fg, custom_opt=topo_constant_folding shapes_fg, custom_rewrite=topo_constant_folding
).outputs ).outputs
sx = canon_shapes[: len(sx)] sx = canon_shapes[: len(sx)]
......
...@@ -446,7 +446,7 @@ The following is an example that distributes dot products across additions. ...@@ -446,7 +446,7 @@ The following is an example that distributes dot products across additions.
import aesara.tensor as at import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import EquilibriumGraphRewriter from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.tensor.math import _dot from aesara.tensor.math import _dot
from etuples import etuple from etuples import etuple
from kanren import conso, eq, fact, heado, tailo from kanren import conso, eq, fact, heado, tailo
...@@ -499,7 +499,7 @@ Below, we apply `dot_distribute_rewrite` to a few example graphs. First we crea ...@@ -499,7 +499,7 @@ Below, we apply `dot_distribute_rewrite` to a few example graphs. First we crea
Next we apply the rewrite to the graph: Next we apply the rewrite to the graph:
>>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_rewrite, clone=False) >>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
>>> print(aesara.pprint(res)) >>> print(aesara.pprint(res))
((A @ x) + (A @ y)) ((A @ x) + (A @ y))
...@@ -511,7 +511,7 @@ few more test cases: ...@@ -511,7 +511,7 @@ few more test cases:
>>> test_at = A_at.dot((x_at + y_at) + (z_at + w_at)) >>> test_at = A_at.dot((x_at + y_at) + (z_at + w_at))
>>> print(aesara.pprint(test_at)) >>> print(aesara.pprint(test_at))
(A @ ((x + y) + (z + w))) (A @ ((x + y) + (z + w)))
>>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_rewrite, clone=False) >>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
>>> print(aesara.pprint(res)) >>> print(aesara.pprint(res))
(((A @ x) + (A @ y)) + ((A @ z) + (A @ w))) (((A @ x) + (A @ y)) + ((A @ z) + (A @ w)))
...@@ -520,7 +520,7 @@ few more test cases: ...@@ -520,7 +520,7 @@ few more test cases:
>>> test_at = A_at.dot(x_at + (y_at + B_at.dot(z_at + w_at))) >>> test_at = A_at.dot(x_at + (y_at + B_at.dot(z_at + w_at)))
>>> print(aesara.pprint(test_at)) >>> print(aesara.pprint(test_at))
(A @ (x + (y + ((B @ z) + (B @ w))))) (A @ (x + (y + ((B @ z) + (B @ w)))))
>>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_rewrite, clone=False) >>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
>>> print(aesara.pprint(res)) >>> print(aesara.pprint(res))
((A @ x) + ((A @ y) + ((A @ (B @ z)) + (A @ (B @ w))))) ((A @ x) + ((A @ y) + ((A @ (B @ z)) + (A @ (B @ w)))))
...@@ -533,7 +533,7 @@ To do that, we will create another :class:`Rewriter` that simply reverses the ar ...@@ -533,7 +533,7 @@ To do that, we will create another :class:`Rewriter` that simply reverses the ar
to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``: to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``:
>>> dot_gather_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10) >>> dot_gather_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_rewrite, clone=False) >>> rev_res = rewrite_graph(res, include=[], custom_rewrite=dot_gather_rewrite, clone=False)
>>> print(aesara.pprint(rev_res)) >>> print(aesara.pprint(rev_res))
(A @ (x + (y + (B @ (z + w))))) (A @ (x + (y + (B @ (z + w)))))
......
...@@ -12,7 +12,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad ...@@ -12,7 +12,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.basic import equal_computations from aesara.graph.basic import equal_computations
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.printing import debugprint from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor from aesara.tensor.basic import as_tensor
...@@ -455,7 +455,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -455,7 +455,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
op_var = op_graph(x, y, z) op_var = op_graph(x, y, z)
fg = FunctionGraph(outputs=[op_var[1]], clone=False) fg = FunctionGraph(outputs=[op_var[1]], clone=False)
opt_res = optimize_graph(fg, custom_opt=ShapeOptimizer()) opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer())
assert opt_res.shape_feature.shape_of[x] is None assert opt_res.shape_feature.shape_of[x] is None
assert opt_res.shape_feature.shape_of[z][0].data == 2 assert opt_res.shape_feature.shape_of[z][0].data == 2
......
...@@ -14,7 +14,7 @@ from aesara.graph.fg import FunctionGraph ...@@ -14,7 +14,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.kanren import KanrenRelationSub from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import EquilibriumGraphRewriter from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.unify import eval_if_etuple from aesara.graph.unify import eval_if_etuple
from aesara.tensor.math import Dot, _dot from aesara.tensor.math import Dot, _dot
from tests.graph.utils import MyType, MyVariable from tests.graph.utils import MyType, MyVariable
...@@ -155,7 +155,7 @@ def test_KanrenRelationSub_dot(): ...@@ -155,7 +155,7 @@ def test_KanrenRelationSub_dot():
[KanrenRelationSub(distributes)], max_use_ratio=10 [KanrenRelationSub(distributes)], max_use_ratio=10
) )
fgraph_opt = optimize_graph(fgraph, custom_opt=distribute_opt) fgraph_opt = rewrite_graph(fgraph, custom_rewrite=distribute_opt)
(expr_opt,) = fgraph_opt.outputs (expr_opt,) = fgraph_opt.outputs
assert expr_opt.owner.op == at.add assert expr_opt.owner.op == at.add
......
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import optimizer from aesara.graph.opt import graph_rewriter
from aesara.graph.opt_utils import is_same_graph, optimize_graph from aesara.graph.opt_utils import is_same_graph, rewrite_graph
from aesara.tensor.math import neg from aesara.tensor.math import neg
from aesara.tensor.type import vectors from aesara.tensor.type import vectors
...@@ -139,20 +139,20 @@ class TestIsSameGraph: ...@@ -139,20 +139,20 @@ class TestIsSameGraph:
) )
def test_optimize_graph(): def test_rewrite_graph():
x, y = vectors("xy") x, y = vectors("xy")
@optimizer @graph_rewriter
def custom_opt(fgraph): def custom_rewrite(fgraph):
fgraph.replace(x, y, import_missing=True) fgraph.replace(x, y, import_missing=True)
x_opt = optimize_graph(x, custom_opt=custom_opt) x_rewritten = rewrite_graph(x, custom_rewrite=custom_rewrite)
assert x_opt is y assert x_rewritten is y
x_opt = optimize_graph( x_rewritten = rewrite_graph(
FunctionGraph(outputs=[x], clone=False), custom_opt=custom_opt FunctionGraph(outputs=[x], clone=False), custom_rewrite=custom_rewrite
) )
assert x_opt.outputs[0] is y assert x_rewritten.outputs[0] is y
...@@ -17,7 +17,7 @@ from aesara.graph.basic import Apply, Constant, Variable ...@@ -17,7 +17,7 @@ from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, node_rewriter, out2in from aesara.graph.opt import check_stack_trace, node_rewriter, out2in
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -1817,7 +1817,7 @@ class TestUselessCheckAndRaise: ...@@ -1817,7 +1817,7 @@ class TestUselessCheckAndRaise:
"""Remove `CheckAndRaise`s when all the conditions are always true.""" """Remove `CheckAndRaise`s when all the conditions are always true."""
x = scalar() x = scalar()
fg = FunctionGraph(outputs=[assert_op(x, 1)], clone=False) fg = FunctionGraph(outputs=[assert_op(x, 1)], clone=False)
fg_res = optimize_graph(fg, include=["canonicalize", "specialize"]) fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort() topo = fg_res.toposort()
assert not any(isinstance(node.op, CheckAndRaise) for node in topo) assert not any(isinstance(node.op, CheckAndRaise) for node in topo)
...@@ -1826,7 +1826,7 @@ class TestUselessCheckAndRaise: ...@@ -1826,7 +1826,7 @@ class TestUselessCheckAndRaise:
x = scalar() x = scalar()
y = scalar() y = scalar()
fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False) fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
fg_res = optimize_graph(fg, include=["canonicalize", "specialize"]) fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort() topo = fg_res.toposort()
(assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)] (assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)]
assert assert_node.inputs == [x, y] assert assert_node.inputs == [x, y]
...@@ -1836,7 +1836,7 @@ class TestUselessCheckAndRaise: ...@@ -1836,7 +1836,7 @@ class TestUselessCheckAndRaise:
x = scalar() x = scalar()
y = scalar() y = scalar()
fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False) fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False)
fg_res = optimize_graph(fg, include=["canonicalize", "specialize"]) fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort() topo = fg_res.toposort()
(assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)] (assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)]
assert assert_node.inputs[:2] == [x, y] assert assert_node.inputs[:2] == [x, y]
...@@ -3017,7 +3017,7 @@ def test_local_Shape_of_SpecifyShape(shape): ...@@ -3017,7 +3017,7 @@ def test_local_Shape_of_SpecifyShape(shape):
s = specify_shape(x, shape).shape s = specify_shape(x, shape).shape
fgraph = FunctionGraph(outputs=[s], clone=False) fgraph = FunctionGraph(outputs=[s], clone=False)
_ = optimize_graph(fgraph, clone=False) _ = rewrite_graph(fgraph, clone=False)
assert x not in fgraph.variables assert x not in fgraph.variables
assert shape in fgraph.variables assert shape in fgraph.variables
...@@ -3034,7 +3034,7 @@ def test_local_Shape_of_SpecifyShape_partial(s1): ...@@ -3034,7 +3034,7 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
fgraph = FunctionGraph(outputs=[s], clone=False) fgraph = FunctionGraph(outputs=[s], clone=False)
assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
_ = optimize_graph(fgraph, clone=False) _ = rewrite_graph(fgraph, clone=False)
assert x in fgraph.variables assert x in fgraph.variables
assert s1 in fgraph.variables assert s1 in fgraph.variables
...@@ -3046,7 +3046,7 @@ def test_local_Shape_i_of_broadcastable(): ...@@ -3046,7 +3046,7 @@ def test_local_Shape_i_of_broadcastable():
s = Shape_i(1)(x) s = Shape_i(1)(x)
fgraph = FunctionGraph(outputs=[s], clone=False) fgraph = FunctionGraph(outputs=[s], clone=False)
_ = optimize_graph(fgraph, clone=False) _ = rewrite_graph(fgraph, clone=False)
assert x not in fgraph.variables assert x not in fgraph.variables
assert fgraph.outputs[0].data == 1 assert fgraph.outputs[0].data == 1
...@@ -3067,7 +3067,7 @@ def test_local_Shape_i_of_broadcastable(): ...@@ -3067,7 +3067,7 @@ def test_local_Shape_i_of_broadcastable():
x = MyVariable(MyType(), None, None) x = MyVariable(MyType(), None, None)
s = Shape_i(0)(x) s = Shape_i(0)(x)
fgraph = FunctionGraph(outputs=[s], clone=False) fgraph = FunctionGraph(outputs=[s], clone=False)
_ = optimize_graph(fgraph, clone=False) _ = rewrite_graph(fgraph, clone=False)
assert fgraph.outputs[0] == s assert fgraph.outputs[0] == s
...@@ -3197,7 +3197,7 @@ def test_local_Unique_scalar(return_index, return_counts, return_inverse): ...@@ -3197,7 +3197,7 @@ def test_local_Unique_scalar(return_index, return_counts, return_inverse):
) )
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"] y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
) )
y_rewritten = y_rewritten_fg.outputs[0] y_rewritten = y_rewritten_fg.outputs[0]
...@@ -3243,7 +3243,7 @@ def test_local_Unique_Alloc_lift( ...@@ -3243,7 +3243,7 @@ def test_local_Unique_Alloc_lift(
# This approach allows us to directly confirm that `x` is in the result. # This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_Unique_Alloc_lift"], include=["canonicalize", "local_Unique_Alloc_lift"],
...@@ -3301,7 +3301,7 @@ def test_local_Unique_BroadcastTo( ...@@ -3301,7 +3301,7 @@ def test_local_Unique_BroadcastTo(
# This approach allows us to directly confirm that `x` is in the result. # This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_Unique_BroadcastTo_lift"], include=["canonicalize", "local_Unique_BroadcastTo_lift"],
...@@ -3364,7 +3364,7 @@ def test_local_Unique_Repeat( ...@@ -3364,7 +3364,7 @@ def test_local_Unique_Repeat(
# This approach allows us to directly confirm that `x` is in the result. # This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_Unique_Repeat_lift"], include=["canonicalize", "local_Unique_Repeat_lift"],
...@@ -3420,7 +3420,7 @@ def test_local_Unique_second( ...@@ -3420,7 +3420,7 @@ def test_local_Unique_second(
# This approach allows us to directly confirm that `x` is in the result. # This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_Unique_second_lift"], include=["canonicalize", "local_Unique_second_lift"],
...@@ -3466,7 +3466,7 @@ def test_local_merge_consecutive_specify_shape(): ...@@ -3466,7 +3466,7 @@ def test_local_merge_consecutive_specify_shape():
y = specify_shape(specify_shape(x, s), s) y = specify_shape(specify_shape(x, s), s)
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_merge_consecutive_specify_shape"], include=["canonicalize", "local_merge_consecutive_specify_shape"],
...@@ -3483,7 +3483,7 @@ def test_local_merge_consecutive_specify_shape2(): ...@@ -3483,7 +3483,7 @@ def test_local_merge_consecutive_specify_shape2():
y = specify_shape(specify_shape(x, [s1, s2, None]), [None, s3, s4]) y = specify_shape(specify_shape(x, [s1, s2, None]), [None, s3, s4])
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_merge_consecutive_specify_shape"], include=["canonicalize", "local_merge_consecutive_specify_shape"],
...@@ -3507,7 +3507,7 @@ def test_local_remove_scalar_BroadcastTo(): ...@@ -3507,7 +3507,7 @@ def test_local_remove_scalar_BroadcastTo():
assert isinstance(y.owner.op, BroadcastTo) assert isinstance(y.owner.op, BroadcastTo)
res = optimize_graph( res = rewrite_graph(
y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"] y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
) )
...@@ -3521,7 +3521,7 @@ def test_local_useless_dimshuffle_makevector(): ...@@ -3521,7 +3521,7 @@ def test_local_useless_dimshuffle_makevector():
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=["canonicalize", "local_useless_dimshuffle_makevector"], include=["canonicalize", "local_useless_dimshuffle_makevector"],
...@@ -3544,7 +3544,7 @@ def test_Shape_i_canonicalize(): ...@@ -3544,7 +3544,7 @@ def test_Shape_i_canonicalize():
y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()]) y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()])
y_rewritten_fg = optimize_graph( y_rewritten_fg = rewrite_graph(
y_fg, y_fg,
clone=False, clone=False,
include=[ include=[
...@@ -3686,7 +3686,7 @@ class TestLocalElemwiseAlloc: ...@@ -3686,7 +3686,7 @@ class TestLocalElemwiseAlloc:
z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()]) z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()])
z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"]) z_opt_fg = rewrite_graph(z_fg, clone=False, include=["local_elemwise_alloc"])
assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes) assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes)
def test_remove_alloc_wo_dimshuffle(self): def test_remove_alloc_wo_dimshuffle(self):
......
...@@ -25,7 +25,7 @@ from aesara.graph.opt import ( ...@@ -25,7 +25,7 @@ from aesara.graph.opt import (
in2out, in2out,
out2in, out2in,
) )
from aesara.graph.opt_utils import is_same_graph, optimize_graph from aesara.graph.opt_utils import is_same_graph, rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
...@@ -251,7 +251,7 @@ class TestAlgebraicCanonizer: ...@@ -251,7 +251,7 @@ class TestAlgebraicCanonizer:
], ],
) )
def test_muldiv(self, e, exp_g): def test_muldiv(self, e, exp_g):
g_rewritten = optimize_graph(e, custom_opt=mul_canonizer) g_rewritten = rewrite_graph(e, custom_rewrite=mul_canonizer)
assert equal_computations([g_rewritten], [exp_g]) assert equal_computations([g_rewritten], [exp_g])
def test_elemwise_multiple_inputs_rewrites(self): def test_elemwise_multiple_inputs_rewrites(self):
...@@ -966,8 +966,8 @@ class TestAlgebraicCanonizer: ...@@ -966,8 +966,8 @@ class TestAlgebraicCanonizer:
z.owner.op, z.owner.inputs, [tensor("float64", (None, None))] z.owner.op, z.owner.inputs, [tensor("float64", (None, None))]
).outputs[0] ).outputs[0]
z_rewritten = optimize_graph( z_rewritten = rewrite_graph(
z, custom_opt=in2out(local_mul_canonizer, name="blah") z, custom_rewrite=in2out(local_mul_canonizer, name="blah")
) )
# No rewrite was applied # No rewrite was applied
assert z_rewritten is z assert z_rewritten is z
...@@ -4140,7 +4140,7 @@ def test_local_log_sum_exp_inf(): ...@@ -4140,7 +4140,7 @@ def test_local_log_sum_exp_inf():
def test_local_reciprocal_1_plus_exp(): def test_local_reciprocal_1_plus_exp():
x = vector("x") x = vector("x")
y = at.reciprocal(1 + exp(x)) y = at.reciprocal(1 + exp(x))
z = optimize_graph(y, include=["canonicalization", "stabilize", "specialize"]) z = rewrite_graph(y, include=["canonicalization", "stabilize", "specialize"])
assert z.owner.op == sigmoid assert z.owner.op == sigmoid
......
...@@ -11,7 +11,7 @@ from aesara.compile.ops import DeepCopyOp ...@@ -11,7 +11,7 @@ from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, ancestors from aesara.graph.basic import Constant, Variable, ancestors
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.raise_op import Assert from aesara.raise_op import Assert
...@@ -1907,7 +1907,7 @@ def test_local_subtensor_shape_constant(): ...@@ -1907,7 +1907,7 @@ def test_local_subtensor_shape_constant():
assert res.data == 1 assert res.data == 1
# Make sure it's part of the canonicalizations # Make sure it's part of the canonicalizations
res = optimize_graph(x) res = rewrite_graph(x)
assert isinstance(res, Constant) assert isinstance(res, Constant)
assert res.data == 1 assert res.data == 1
...@@ -2003,7 +2003,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -2003,7 +2003,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val])) y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val]))
# This optimization should appear in the canonicalizations # This optimization should appear in the canonicalizations
y_opt = optimize_graph(y, clone=False) y_opt = rewrite_graph(y, clone=False)
if y.ndim == 0: if y.ndim == 0:
# SpecifyShape should be removed altogether # SpecifyShape should be removed altogether
...@@ -2042,7 +2042,7 @@ def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): ...@@ -2042,7 +2042,7 @@ def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
y = specify_shape(x, s)[idx] y = specify_shape(x, s)[idx]
# This optimization should appear in the canonicalizations # This optimization should appear in the canonicalizations
y_opt = optimize_graph(y, clone=False) y_opt = rewrite_graph(y, clone=False)
assert not isinstance(y_opt.owner.op, SpecifyShape) assert not isinstance(y_opt.owner.op, SpecifyShape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论