提交 fa74d7e3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.graph.toolbox to aesara.graph.features

上级 4403e131
......@@ -30,9 +30,9 @@ from aesara.compile.ops import OutputGuard, _output_guard
from aesara.configdefaults import config
from aesara.graph.basic import Variable, graph_inputs, io_toposort
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op, ops_with_inner_function
from aesara.graph.toolbox import BadOptimization
from aesara.graph.utils import MethodNotDefined
from aesara.link.basic import Container, LocalLinker
from aesara.link.utils import map_storage, raise_with_op
......
......@@ -30,9 +30,9 @@ from aesara.graph.basic import (
vars_between,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes, is_same_graph
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import ops_with_inner_function
from aesara.graph.toolbox import PreserveVariableAttributes, is_same_graph
from aesara.graph.utils import get_variable_trace_string
from aesara.link.basic import Container
from aesara.link.utils import raise_with_op
......
......@@ -149,7 +149,7 @@ from aesara.gpuarray.type import (
get_context,
move_to_gpu,
)
from aesara.graph import toolbox
from aesara.graph import features
from aesara.graph.basic import Constant, Variable, applys_between, clone_replace
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import (
......@@ -244,7 +244,7 @@ class InputToGpuOptimizer(GlobalOptimizer):
"""
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(features.ReplaceValidate())
def apply(self, fgraph):
for input in fgraph.inputs:
......@@ -305,7 +305,7 @@ class GraphToGPU(GlobalOptimizer):
self.local_optimizers_map = local_optimizers_map
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(features.ReplaceValidate())
def apply(self, fgraph):
mapping = {}
......
......@@ -9,8 +9,8 @@ from collections import OrderedDict, deque
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.features import AlreadyThere, Bookkeeper
from aesara.graph.fg import InconsistencyError
from aesara.graph.toolbox import AlreadyThere, Bookkeeper
from aesara.misc.ordered_set import OrderedSet
......
差异被折叠。
......@@ -8,7 +8,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.toolbox import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string
from aesara.misc.ordered_set import OrderedSet
......@@ -548,16 +548,12 @@ class FunctionGraph(MetaObject):
)
def replace_all(self, pairs: List[Tuple[Variable, Variable]], **kwargs) -> NoReturn:
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
"""Replace variables in the ``FunctionGraph`` according to ``(var, new_var)`` pairs in a list."""
for var, new_var in pairs:
self.replace(var, new_var, **kwargs)
def attach_feature(self, feature: Feature) -> NoReturn:
"""
Adds a graph.toolbox.Feature to this function_graph and triggers its
on_attach callback.
"""
"""Add a ``graph.features.Feature`` to this function graph and trigger its on_attach callback."""
# Filter out literally identical `Feature`s
if feature in self._features:
return # the feature is already present
......
......@@ -31,9 +31,9 @@ from aesara.graph.basic import (
io_toposort,
nodes_constructed,
)
from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import Op
from aesara.graph.toolbox import Feature, NodeFinder
from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet
from aesara.utils import flatten
......
差异被折叠。
......@@ -600,7 +600,7 @@ class CondMerge(GlobalOptimizer):
""" Graph Optimizer that merges different cond ops """
def add_requirements(self, fgraph):
from aesara.graph.toolbox import ReplaceValidate
from aesara.graph.features import ReplaceValidate
fgraph.add_feature(ReplaceValidate())
......
......@@ -70,9 +70,9 @@ from aesara.graph.basic import (
graph_inputs,
io_connection_pattern,
)
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op, ops_with_inner_function
from aesara.graph.toolbox import NoOutputFromInplace
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
......
......@@ -73,11 +73,11 @@ from aesara.graph.basic import (
is_in_ancestors,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.toolbox import ReplaceValidate
from aesara.scan.op import Scan
from aesara.scan.utils import (
compress_outs,
......
......@@ -18,7 +18,7 @@ from aesara import compile
from aesara.assert_op import Assert, assert_op
from aesara.compile.ops import ViewOp
from aesara.configdefaults import config
from aesara.graph import toolbox
from aesara.graph import features
from aesara.graph.basic import (
Constant,
Variable,
......@@ -797,7 +797,7 @@ class MakeVectorPrinter:
pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(toolbox.Feature):
class ShapeFeature(features.Feature):
"""Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with
......@@ -4674,7 +4674,7 @@ class FusionOptimizer(GlobalOptimizer):
self.optimizer = local_optimizer
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(features.ReplaceValidate())
def apply(self, fgraph):
did_something = True
......
......@@ -146,6 +146,7 @@ import aesara.scalar
from aesara.compile.mode import optdb
from aesara.configdefaults import config
from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op
from aesara.graph.opt import (
......@@ -157,7 +158,6 @@ from aesara.graph.opt import (
)
from aesara.graph.optdb import SequenceDB
from aesara.graph.params_type import ParamsType
from aesara.graph.toolbox import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.utils import MethodNotDefined, TestValueError
from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t
......
......@@ -123,7 +123,7 @@ simplification described above:
import aesara
from aesara.graph.opt import GlobalOptimizer
from aesara.graph.toolbox import ReplaceValidate
from aesara.graph.features import ReplaceValidate
class Simplify(GlobalOptimizer):
def add_requirements(self, fgraph):
......@@ -149,12 +149,12 @@ simplification described above:
Here's how it works: first, in ``add_requirements``, we add the
``ReplaceValidate`` :ref:`libdoc_graph_fgraphfeature` located in
:ref:`libdoc_graph_toolbox`. This feature adds the ``replace_validate``
:ref:`libdoc_graph_features`. This feature adds the ``replace_validate``
method to ``fgraph``, which is an enhanced version of ``replace`` that
does additional checks to ensure that we are not messing up the
computation graph (note: if ``ReplaceValidate`` was already added by
another optimizer, ``extend`` will do nothing). In a nutshell,
``toolbox.ReplaceValidate`` grants access to ``fgraph.replace_validate``,
``features.ReplaceValidate`` grants access to ``fgraph.replace_validate``,
and ``fgraph.replace_validate`` allows us to replace a Variable with
another while respecting certain validation constraints. You can
browse the list of :ref:`libdoc_graph_fgraphfeaturelist` and see if some of
......
.. _libdoc_graph_toolbox:
.. _libdoc_graph_features:
================================================
:mod:`toolbox` -- [doc TODO]
:mod:`features` -- [doc TODO]
================================================
.. module:: aesara.graph.toolbox
.. module:: aesara.graph.features
:platform: Unix, Windows
:synopsis: Aesara Internals
.. moduleauthor:: LISA
......
......@@ -33,7 +33,7 @@ FunctionGraph
FunctionGraph Features
----------------------
.. autoclass:: aesara.graph.toolbox.Feature
.. autoclass:: aesara.graph.features.Feature
:members:
.. _libdoc_graph_fgraphfeaturelist:
......
......@@ -15,7 +15,7 @@
graph
fgraph
toolbox
features
op
type
params_type
......
......@@ -4,6 +4,7 @@ ignore = E203,E231,E501,E741,W503,W504,C901
max-line-length = 88
per-file-ignores =
**/__init__.py:F401,E402,F403
aesara/graph/toolbox.py:E402,F403,F401
aesara/link/jax/jax_dispatch.py:E402,F403,F401
aesara/link/jax/jax_linker.py:E402,F403,F401
aesara/sparse/sandbox/sp2.py:F401
......
......@@ -15,10 +15,10 @@ from aesara.compile.debugmode import (
)
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.features import BadOptimization
from aesara.graph.op import COp, Op
from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB
from aesara.graph.toolbox import BadOptimization
from aesara.tensor.math import add, dot, log
from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, vector
from tests import unittest_tools as utt
......
......@@ -2,7 +2,7 @@ import pytest
import aesara
from aesara.compile.mode import AddFeatureOptimizer, Mode
from aesara.graph.toolbox import NoOutputFromInplace
from aesara.graph.features import NoOutputFromInplace
from aesara.tensor.math import dot, tanh
from aesara.tensor.type import matrix
......
......@@ -5,6 +5,7 @@ import pytest
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, clone
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import Op
from aesara.graph.opt import (
......@@ -14,7 +15,6 @@ from aesara.graph.opt import (
PatternSub,
TopoOptimizer,
)
from aesara.graph.toolbox import ReplaceValidate
from aesara.graph.type import Type
from tests.unittest_tools import assertFailure_fast
......
from aesara.graph.basic import Apply, Variable
from aesara.graph.features import NodeFinder, is_same_graph
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.toolbox import NodeFinder, is_same_graph
from aesara.graph.type import Type
from aesara.tensor.math import neg
from aesara.tensor.type import vectors
......
......@@ -17,10 +17,10 @@ from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.features import is_same_graph
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in
from aesara.graph.optdb import Query
from aesara.graph.toolbox import is_same_graph
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch
......
......@@ -12,8 +12,8 @@ import aesara.tensor.basic as aet
from aesara.compile import DeepCopyOp, shared
from aesara.compile.io import In
from aesara.configdefaults import config
from aesara.graph.features import is_same_graph
from aesara.graph.op import get_test_value
from aesara.graph.toolbox import is_same_graph
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import exp, isinf
from aesara.tensor.math import sum as aet_sum
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论