提交 fa74d7e3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.graph.toolbox to aesara.graph.features

上级 4403e131
...@@ -30,9 +30,9 @@ from aesara.compile.ops import OutputGuard, _output_guard ...@@ -30,9 +30,9 @@ from aesara.compile.ops import OutputGuard, _output_guard
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Variable, graph_inputs, io_toposort from aesara.graph.basic import Variable, graph_inputs, io_toposort
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op, ops_with_inner_function from aesara.graph.op import COp, Op, ops_with_inner_function
from aesara.graph.toolbox import BadOptimization
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.basic import Container, LocalLinker from aesara.link.basic import Container, LocalLinker
from aesara.link.utils import map_storage, raise_with_op from aesara.link.utils import map_storage, raise_with_op
......
...@@ -30,9 +30,9 @@ from aesara.graph.basic import ( ...@@ -30,9 +30,9 @@ from aesara.graph.basic import (
vars_between, vars_between,
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes, is_same_graph
from aesara.graph.fg import FunctionGraph, InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import ops_with_inner_function from aesara.graph.op import ops_with_inner_function
from aesara.graph.toolbox import PreserveVariableAttributes, is_same_graph
from aesara.graph.utils import get_variable_trace_string from aesara.graph.utils import get_variable_trace_string
from aesara.link.basic import Container from aesara.link.basic import Container
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
......
...@@ -149,7 +149,7 @@ from aesara.gpuarray.type import ( ...@@ -149,7 +149,7 @@ from aesara.gpuarray.type import (
get_context, get_context,
move_to_gpu, move_to_gpu,
) )
from aesara.graph import toolbox from aesara.graph import features
from aesara.graph.basic import Constant, Variable, applys_between, clone_replace from aesara.graph.basic import Constant, Variable, applys_between, clone_replace
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import ( from aesara.graph.opt import (
...@@ -244,7 +244,7 @@ class InputToGpuOptimizer(GlobalOptimizer): ...@@ -244,7 +244,7 @@ class InputToGpuOptimizer(GlobalOptimizer):
""" """
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(features.ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
for input in fgraph.inputs: for input in fgraph.inputs:
...@@ -305,7 +305,7 @@ class GraphToGPU(GlobalOptimizer): ...@@ -305,7 +305,7 @@ class GraphToGPU(GlobalOptimizer):
self.local_optimizers_map = local_optimizers_map self.local_optimizers_map = local_optimizers_map
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(features.ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
mapping = {} mapping = {}
......
...@@ -9,8 +9,8 @@ from collections import OrderedDict, deque ...@@ -9,8 +9,8 @@ from collections import OrderedDict, deque
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.features import AlreadyThere, Bookkeeper
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.toolbox import AlreadyThere, Bookkeeper
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
......
差异被折叠。
...@@ -8,7 +8,7 @@ from aesara.configdefaults import config ...@@ -8,7 +8,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, applys_between from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.toolbox import AlreadyThere, Feature, ReplaceValidate from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
...@@ -548,16 +548,12 @@ class FunctionGraph(MetaObject): ...@@ -548,16 +548,12 @@ class FunctionGraph(MetaObject):
) )
def replace_all(self, pairs: List[Tuple[Variable, Variable]], **kwargs) -> NoReturn: def replace_all(self, pairs: List[Tuple[Variable, Variable]], **kwargs) -> NoReturn:
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list.""" """Replace variables in the ``FunctionGraph`` according to ``(var, new_var)`` pairs in a list."""
for var, new_var in pairs: for var, new_var in pairs:
self.replace(var, new_var, **kwargs) self.replace(var, new_var, **kwargs)
def attach_feature(self, feature: Feature) -> NoReturn: def attach_feature(self, feature: Feature) -> NoReturn:
""" """Add a ``graph.features.Feature`` to this function graph and trigger its on_attach callback."""
Adds a graph.toolbox.Feature to this function_graph and triggers its
on_attach callback.
"""
# Filter out literally identical `Feature`s # Filter out literally identical `Feature`s
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
......
...@@ -31,9 +31,9 @@ from aesara.graph.basic import ( ...@@ -31,9 +31,9 @@ from aesara.graph.basic import (
io_toposort, io_toposort,
nodes_constructed, nodes_constructed,
) )
from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.toolbox import Feature, NodeFinder
from aesara.graph.utils import AssocList from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
from aesara.utils import flatten from aesara.utils import flatten
......
差异被折叠。
...@@ -600,7 +600,7 @@ class CondMerge(GlobalOptimizer): ...@@ -600,7 +600,7 @@ class CondMerge(GlobalOptimizer):
""" Graph Optimizer that merges different cond ops """ """ Graph Optimizer that merges different cond ops """
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
from aesara.graph.toolbox import ReplaceValidate from aesara.graph.features import ReplaceValidate
fgraph.add_feature(ReplaceValidate()) fgraph.add_feature(ReplaceValidate())
......
...@@ -70,9 +70,9 @@ from aesara.graph.basic import ( ...@@ -70,9 +70,9 @@ from aesara.graph.basic import (
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
) )
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import MissingInputError from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op, ops_with_inner_function from aesara.graph.op import Op, ops_with_inner_function
from aesara.graph.toolbox import NoOutputFromInplace
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
......
...@@ -73,11 +73,11 @@ from aesara.graph.basic import ( ...@@ -73,11 +73,11 @@ from aesara.graph.basic import (
is_in_ancestors, is_in_ancestors,
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.toolbox import ReplaceValidate
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import ( from aesara.scan.utils import (
compress_outs, compress_outs,
......
...@@ -18,7 +18,7 @@ from aesara import compile ...@@ -18,7 +18,7 @@ from aesara import compile
from aesara.assert_op import Assert, assert_op from aesara.assert_op import Assert, assert_op
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import toolbox from aesara.graph import features
from aesara.graph.basic import ( from aesara.graph.basic import (
Constant, Constant,
Variable, Variable,
...@@ -797,7 +797,7 @@ class MakeVectorPrinter: ...@@ -797,7 +797,7 @@ class MakeVectorPrinter:
pprint.assign(MakeVector, MakeVectorPrinter()) pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(toolbox.Feature): class ShapeFeature(features.Feature):
"""Graph optimizer for removing all calls to shape(). """Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with This optimizer replaces all Shapes and Subtensors of Shapes with
...@@ -4674,7 +4674,7 @@ class FusionOptimizer(GlobalOptimizer): ...@@ -4674,7 +4674,7 @@ class FusionOptimizer(GlobalOptimizer):
self.optimizer = local_optimizer self.optimizer = local_optimizer
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(features.ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
did_something = True did_something = True
......
...@@ -146,6 +146,7 @@ import aesara.scalar ...@@ -146,6 +146,7 @@ import aesara.scalar
from aesara.compile.mode import optdb from aesara.compile.mode import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, view_roots from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
from aesara.graph.opt import ( from aesara.graph.opt import (
...@@ -157,7 +158,6 @@ from aesara.graph.opt import ( ...@@ -157,7 +158,6 @@ from aesara.graph.opt import (
) )
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import SequenceDB
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.toolbox import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.utils import MethodNotDefined, TestValueError from aesara.graph.utils import MethodNotDefined, TestValueError
from aesara.printing import FunctionPrinter, debugprint, pprint from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
......
...@@ -123,7 +123,7 @@ simplification described above: ...@@ -123,7 +123,7 @@ simplification described above:
import aesara import aesara
from aesara.graph.opt import GlobalOptimizer from aesara.graph.opt import GlobalOptimizer
from aesara.graph.toolbox import ReplaceValidate from aesara.graph.features import ReplaceValidate
class Simplify(GlobalOptimizer): class Simplify(GlobalOptimizer):
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
...@@ -149,12 +149,12 @@ simplification described above: ...@@ -149,12 +149,12 @@ simplification described above:
Here's how it works: first, in ``add_requirements``, we add the Here's how it works: first, in ``add_requirements``, we add the
``ReplaceValidate`` :ref:`libdoc_graph_fgraphfeature` located in ``ReplaceValidate`` :ref:`libdoc_graph_fgraphfeature` located in
:ref:`libdoc_graph_toolbox`. This feature adds the ``replace_validate`` :ref:`libdoc_graph_features`. This feature adds the ``replace_validate``
method to ``fgraph``, which is an enhanced version of ``replace`` that method to ``fgraph``, which is an enhanced version of ``replace`` that
does additional checks to ensure that we are not messing up the does additional checks to ensure that we are not messing up the
computation graph (note: if ``ReplaceValidate`` was already added by computation graph (note: if ``ReplaceValidate`` was already added by
another optimizer, ``extend`` will do nothing). In a nutshell, another optimizer, ``extend`` will do nothing). In a nutshell,
``toolbox.ReplaceValidate`` grants access to ``fgraph.replace_validate``, ``features.ReplaceValidate`` grants access to ``fgraph.replace_validate``,
and ``fgraph.replace_validate`` allows us to replace a Variable with and ``fgraph.replace_validate`` allows us to replace a Variable with
another while respecting certain validation constraints. You can another while respecting certain validation constraints. You can
browse the list of :ref:`libdoc_graph_fgraphfeaturelist` and see if some of browse the list of :ref:`libdoc_graph_fgraphfeaturelist` and see if some of
......
.. _libdoc_graph_toolbox: .. _libdoc_graph_features:
================================================ ================================================
:mod:`toolbox` -- [doc TODO] :mod:`features` -- [doc TODO]
================================================ ================================================
.. module:: aesara.graph.toolbox .. module:: aesara.graph.features
:platform: Unix, Windows :platform: Unix, Windows
:synopsis: Aesara Internals :synopsis: Aesara Internals
.. moduleauthor:: LISA .. moduleauthor:: LISA
......
...@@ -33,7 +33,7 @@ FunctionGraph ...@@ -33,7 +33,7 @@ FunctionGraph
FunctionGraph Features FunctionGraph Features
---------------------- ----------------------
.. autoclass:: aesara.graph.toolbox.Feature .. autoclass:: aesara.graph.features.Feature
:members: :members:
.. _libdoc_graph_fgraphfeaturelist: .. _libdoc_graph_fgraphfeaturelist:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
graph graph
fgraph fgraph
toolbox features
op op
type type
params_type params_type
......
...@@ -4,6 +4,7 @@ ignore = E203,E231,E501,E741,W503,W504,C901 ...@@ -4,6 +4,7 @@ ignore = E203,E231,E501,E741,W503,W504,C901
max-line-length = 88 max-line-length = 88
per-file-ignores = per-file-ignores =
**/__init__.py:F401,E402,F403 **/__init__.py:F401,E402,F403
aesara/graph/toolbox.py:E402,F403,F401
aesara/link/jax/jax_dispatch.py:E402,F403,F401 aesara/link/jax/jax_dispatch.py:E402,F403,F401
aesara/link/jax/jax_linker.py:E402,F403,F401 aesara/link/jax/jax_linker.py:E402,F403,F401
aesara/sparse/sandbox/sp2.py:F401 aesara/sparse/sandbox/sp2.py:F401
......
...@@ -15,10 +15,10 @@ from aesara.compile.debugmode import ( ...@@ -15,10 +15,10 @@ from aesara.compile.debugmode import (
) )
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.features import BadOptimization
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
from aesara.graph.opt import local_optimizer from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB from aesara.graph.optdb import EquilibriumDB
from aesara.graph.toolbox import BadOptimization
from aesara.tensor.math import add, dot, log from aesara.tensor.math import add, dot, log
from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, vector from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
......
...@@ -2,7 +2,7 @@ import pytest ...@@ -2,7 +2,7 @@ import pytest
import aesara import aesara
from aesara.compile.mode import AddFeatureOptimizer, Mode from aesara.compile.mode import AddFeatureOptimizer, Mode
from aesara.graph.toolbox import NoOutputFromInplace from aesara.graph.features import NoOutputFromInplace
from aesara.tensor.math import dot, tanh from aesara.tensor.math import dot, tanh
from aesara.tensor.type import matrix from aesara.tensor.type import matrix
......
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, clone from aesara.graph.basic import Apply, Constant, Variable, clone
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph, InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
...@@ -14,7 +15,6 @@ from aesara.graph.opt import ( ...@@ -14,7 +15,6 @@ from aesara.graph.opt import (
PatternSub, PatternSub,
TopoOptimizer, TopoOptimizer,
) )
from aesara.graph.toolbox import ReplaceValidate
from aesara.graph.type import Type from aesara.graph.type import Type
from tests.unittest_tools import assertFailure_fast from tests.unittest_tools import assertFailure_fast
......
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.features import NodeFinder, is_same_graph
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.toolbox import NodeFinder, is_same_graph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.tensor.math import neg from aesara.tensor.math import neg
from aesara.tensor.type import vectors from aesara.tensor.type import vectors
......
...@@ -17,10 +17,10 @@ from aesara.compile.mode import Mode, get_default_mode, get_mode ...@@ -17,10 +17,10 @@ from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.features import is_same_graph
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
from aesara.graph.toolbox import is_same_graph
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch from aesara.tensor.basic import Alloc, join, switch
......
...@@ -12,8 +12,8 @@ import aesara.tensor.basic as aet ...@@ -12,8 +12,8 @@ import aesara.tensor.basic as aet
from aesara.compile import DeepCopyOp, shared from aesara.compile import DeepCopyOp, shared
from aesara.compile.io import In from aesara.compile.io import In
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.features import is_same_graph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.toolbox import is_same_graph
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import exp, isinf from aesara.tensor.math import exp, isinf
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论