提交 ced64bfc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move basic tensor rewriting code to aesara.tensor.rewriting

- `aesara.tensor.basic_opt` has been changed to `aesara.tensor.rewriting.basic` - `aesara.tensor.math_opt` has been changed to `aesara.tensor.rewriting.math` - `aesara.tensor.subtensor_opt` has been changed to `aesara.tensor.rewriting.subtensor` - `aesara.tensor.opt_uncanonicalize` has been changed to `aesara.tensor.rewriting.uncanonicalize` The tests associated with each module have been updated accordingly.
上级 57acc845
...@@ -72,9 +72,9 @@ jobs: ...@@ -72,9 +72,9 @@ jobs:
install-numba: [1] install-numba: [1]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/sparse --ignore=tests/tensor/nnet" - "tests --ignore=tests/tensor --ignore=tests/sparse --ignore=tests/tensor/nnet"
- "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_basic_opt.py --ignore=tests/tensor/test_math_opt.py --ignore=tests/tensor/nnet" - "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/rewriting/test_basic.py --ignore=tests/tensor/rewriting/test_math.py --ignore=tests/tensor/nnet"
- "tests/tensor/test_basic.py tests/tensor/test_math.py tests/tensor/test_math_scipy.py tests/tensor/test_inplace.py" - "tests/tensor/test_basic.py tests/tensor/test_math.py tests/tensor/test_math_scipy.py tests/tensor/test_inplace.py"
- "tests/tensor/test_elemwise.py tests/tensor/test_basic_opt.py tests/tensor/test_math_opt.py" - "tests/tensor/test_elemwise.py tests/tensor/rewriting/test_basic.py tests/tensor/rewriting/test_math.py"
- "tests/tensor/nnet --ignore-glob='*/test_abstract_conv.py'" - "tests/tensor/nnet --ignore-glob='*/test_abstract_conv.py'"
- "tests/tensor/nnet/test_abstract_conv.py" - "tests/tensor/nnet/test_abstract_conv.py"
include: include:
......
...@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType ...@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.rewriting.basic import ShapeFeature
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
......
...@@ -2,17 +2,17 @@ import logging ...@@ -2,17 +2,17 @@ import logging
from aesara.graph.rewriting.basic import node_rewriter from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic_opt import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.blas import Dot22 from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import Dot, Prod, dot, log from aesara.tensor.math import Dot, Prod, dot, log
from aesara.tensor.math import pow as at_pow from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod from aesara.tensor.math import prod
from aesara.tensor.nlinalg import Det, MatrixInverse, trace from aesara.tensor.nlinalg import Det, MatrixInverse, trace
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.slinalg import Cholesky, Solve, cholesky, solve from aesara.tensor.slinalg import Cholesky, Solve, cholesky, solve
......
...@@ -41,11 +41,12 @@ from aesara.scan.utils import ( ...@@ -41,11 +41,12 @@ from aesara.scan.utils import (
safe_new, safe_new,
scan_can_remove_outs, scan_can_remove_outs,
) )
from aesara.tensor import basic_opt, math_opt
from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, dot, maximum, minimum from aesara.tensor.math import Dot, dot, maximum, minimum
from aesara.tensor.rewriting import basic as basic_opt
from aesara.tensor.rewriting import math as math_opt
from aesara.tensor.shape import shape from aesara.tensor.shape import shape
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
IncSubtensor, IncSubtensor,
......
...@@ -24,8 +24,8 @@ from aesara.sparse.basic import ( ...@@ -24,8 +24,8 @@ from aesara.sparse.basic import (
) )
from aesara.tensor import blas from aesara.tensor import blas
from aesara.tensor.basic import as_tensor_variable, cast from aesara.tensor.basic import as_tensor_variable, cast
from aesara.tensor.basic_opt import register_canonicalize, register_specialize
from aesara.tensor.math import mul, neg, sub from aesara.tensor.math import mul, neg, sub
from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
from aesara.tensor.shape import shape, specify_shape from aesara.tensor.shape import shape, specify_shape
from aesara.tensor.type import TensorType, tensor from aesara.tensor.type import TensorType, tensor
......
...@@ -103,15 +103,13 @@ from aesara.gradient import consider_constant, grad, hessian, jacobian # noqa ...@@ -103,15 +103,13 @@ from aesara.gradient import consider_constant, grad, hessian, jacobian # noqa
# adds shared-variable constructors # adds shared-variable constructors
from aesara.tensor import sharedvar # noqa from aesara.tensor import sharedvar # noqa
from aesara.tensor import ( # noqa from aesara.tensor import ( # noqa
basic_opt,
blas, blas,
blas_c, blas_c,
blas_scipy, blas_scipy,
nnet, nnet,
opt_uncanonicalize,
subtensor_opt,
xlogx, xlogx,
) )
import aesara.tensor.rewriting
# isort: off # isort: off
......
...@@ -1316,7 +1316,7 @@ def infer_broadcastable(shape): ...@@ -1316,7 +1316,7 @@ def infer_broadcastable(shape):
`shape` will be validated and constant folded in order to determine `shape` will be validated and constant folded in order to determine
which dimensions are broadcastable (i.e. equal to ``1``). which dimensions are broadcastable (i.e. equal to ``1``).
""" """
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
def check_type(s): def check_type(s):
if s.type.dtype in integer_dtypes: if s.type.dtype in integer_dtypes:
......
...@@ -159,11 +159,11 @@ from aesara.link.c.params_type import ParamsType ...@@ -159,11 +159,11 @@ from aesara.link.c.params_type import ParamsType
from aesara.printing import FunctionPrinter, debugprint, pprint from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas_headers import blas_header_text, blas_header_version from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.rewriting.basic import local_dimshuffle_lift
from aesara.tensor.shape import specify_broadcastable from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import ( from aesara.tensor.type import (
DenseTensorType, DenseTensorType,
......
...@@ -806,7 +806,7 @@ class Elemwise(OpenMPOp): ...@@ -806,7 +806,7 @@ class Elemwise(OpenMPOp):
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1: if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError from aesara.tensor.exceptions import ShapeError
raise ShapeError( raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`" "Multiple outputs are not supported by the default `Elemwise.infer_shape`"
......
...@@ -24,11 +24,6 @@ from aesara.raise_op import Assert ...@@ -24,11 +24,6 @@ from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp from aesara.scalar import UnaryScalarOp
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic import ARange, as_tensor_variable from aesara.tensor.basic import ARange, as_tensor_variable
from aesara.tensor.basic_opt import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import Unique from aesara.tensor.extra_ops import Unique
...@@ -50,8 +45,13 @@ from aesara.tensor.math import ( ...@@ -50,8 +45,13 @@ from aesara.tensor.math import (
) )
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh, tensordot, true_div from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.math_opt import local_mul_canonizer
from aesara.tensor.nnet.blocksparse import sparse_block_dot from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.shape import Shape, shape_padleft from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
......
...@@ -8,10 +8,10 @@ from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter ...@@ -8,10 +8,10 @@ from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import mean, prod, reciprocal, sqrt from aesara.tensor.math import mean, prod, reciprocal, sqrt
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.rewriting.basic import register_specialize_device
from aesara.tensor.shape import specify_broadcastable from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
......
...@@ -8,9 +8,9 @@ from aesara.graph.basic import Apply ...@@ -8,9 +8,9 @@ from aesara.graph.basic import Apply
from aesara.graph.rewriting.basic import node_rewriter from aesara.graph.rewriting.basic import node_rewriter
from aesara.link.c.cmodule import GCC_compiler from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.blas import batched_dot from aesara.tensor.blas import batched_dot
from aesara.tensor.extra_ops import cpu_contiguous from aesara.tensor.extra_ops import cpu_contiguous
from aesara.tensor.rewriting.basic import register_canonicalize
from aesara.tensor.type import ftensor3, fvector from aesara.tensor.type import ftensor3, fvector
......
...@@ -13,7 +13,6 @@ from aesara.graph.rewriting.basic import ( ...@@ -13,7 +13,6 @@ from aesara.graph.rewriting.basic import (
in2out, in2out,
node_rewriter, node_rewriter,
) )
from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.nnet.abstract_conv import ( from aesara.tensor.nnet.abstract_conv import (
AbstractConv2d, AbstractConv2d,
AbstractConv2d_gradInputs, AbstractConv2d_gradInputs,
...@@ -34,6 +33,7 @@ from aesara.tensor.nnet.blocksparse import ( ...@@ -34,6 +33,7 @@ from aesara.tensor.nnet.blocksparse import (
from aesara.tensor.nnet.conv import ConvOp, conv2d from aesara.tensor.nnet.conv import ConvOp, conv2d
from aesara.tensor.nnet.corr import CorrMM, CorrMM_gradInputs, CorrMM_gradWeights from aesara.tensor.nnet.corr import CorrMM, CorrMM_gradInputs, CorrMM_gradWeights
from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGradWeights from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGradWeights
from aesara.tensor.rewriting.basic import register_specialize_device
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
......
import aesara.tensor.rewriting.basic
import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize
...@@ -103,7 +103,7 @@ from aesara.tensor.var import TensorConstant ...@@ -103,7 +103,7 @@ from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
_logger = logging.getLogger("aesara.tensor.basic_opt") _logger = logging.getLogger("aesara.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter()) _logger.addFilter(NoDuplicateOptWarningFilter())
......
...@@ -35,19 +35,6 @@ from aesara.tensor.basic import ( ...@@ -35,19 +35,6 @@ from aesara.tensor.basic import (
switch, switch,
zeros_like, zeros_like,
) )
from aesara.tensor.basic_opt import (
FusionOptimizer,
broadcast_like,
encompasses_broadcastable,
fuse_seqopt,
local_fill_sink,
register_canonicalize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
)
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import ( from aesara.tensor.math import (
...@@ -84,6 +71,19 @@ from aesara.tensor.math import pow as at_pow ...@@ -84,6 +71,19 @@ from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import true_div from aesara.tensor.math import true_div
from aesara.tensor.rewriting.basic import (
FusionOptimizer,
broadcast_like,
encompasses_broadcastable,
fuse_seqopt,
local_fill_sink,
register_canonicalize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
)
from aesara.tensor.shape import Shape, Shape_i from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -648,7 +648,7 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -648,7 +648,7 @@ class AlgebraicCanonizer(NodeRewriter):
Examples Examples
-------- --------
>>> import aesara.tensor as at >>> import aesara.tensor as at
>>> from aesara.tensor.math_opt import AlgebraicCanonizer >>> from aesara.tensor.rewriting.math import AlgebraicCanonizer
>>> add_canonizer = AlgebraicCanonizer(add, sub, neg, \\ >>> add_canonizer = AlgebraicCanonizer(add, sub, neg, \\
... lambda n, d: sum(n) - sum(d)) ... lambda n, d: sum(n) - sum(d))
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\ >>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
......
...@@ -28,11 +28,6 @@ from aesara.tensor.basic import ( ...@@ -28,11 +28,6 @@ from aesara.tensor.basic import (
get_scalar_constant_value, get_scalar_constant_value,
switch, switch,
) )
from aesara.tensor.basic_opt import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add from aesara.tensor.math import Dot, add
...@@ -50,6 +45,11 @@ from aesara.tensor.math import ( ...@@ -50,6 +45,11 @@ from aesara.tensor.math import (
minimum, minimum,
or_, or_,
) )
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.shape import ( from aesara.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
...@@ -1390,8 +1390,11 @@ def local_setsubtensor_of_constants(fgraph, node): ...@@ -1390,8 +1390,11 @@ def local_setsubtensor_of_constants(fgraph, node):
def local_adv_sub1_adv_inc_sub1(fgraph, node): def local_adv_sub1_adv_inc_sub1(fgraph, node):
"""Rewrite graphs like ``AdvancedSubtensor1(AdvancedSetSubtensor1(...), ...)``. """Rewrite graphs like ``AdvancedSubtensor1(AdvancedSetSubtensor1(...), ...)``.
.. code::
AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y
Notes Notes
----- -----
This rewrite adds an `AssertOp`; otherwise, it would remove shape and index This rewrite adds an `AssertOp`; otherwise, it would remove shape and index
......
...@@ -31,21 +31,16 @@ supposed to be canonical. ...@@ -31,21 +31,16 @@ supposed to be canonical.
""" """
import logging
from aesara import scalar as aes from aesara import scalar as aes
from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.tensor.basic import Alloc, alloc, constant from aesara.tensor.basic import Alloc, alloc, constant
from aesara.tensor.basic_opt import register_uncanonicalize
from aesara.tensor.elemwise import CAReduce, DimShuffle from aesara.tensor.elemwise import CAReduce, DimShuffle
from aesara.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg from aesara.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
from aesara.tensor.rewriting.basic import register_uncanonicalize
from aesara.tensor.shape import Reshape, reshape from aesara.tensor.shape import Reshape, reshape
from aesara.tensor.subtensor import Subtensor from aesara.tensor.subtensor import Subtensor
_logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
@register_uncanonicalize @register_uncanonicalize
@node_rewriter([MaxAndArgmax]) @node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node): def local_max_and_argmax(fgraph, node):
......
...@@ -63,7 +63,7 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -63,7 +63,7 @@ def shape_of_variables(fgraph, input_shapes):
""" """
if not hasattr(fgraph, "shape_feature"): if not hasattr(fgraph, "shape_feature"):
fgraph.attach_feature(aesara.tensor.basic_opt.ShapeFeature()) fgraph.attach_feature(aesara.tensor.rewriting.basic.ShapeFeature())
input_dims = [ input_dims = [
dimension dimension
......
=================================================================== ================================================
:mod:`tensor.basic_opt` -- Tensor Rewrites :mod:`tensor.rewriting.basic` -- Tensor Rewrites
=================================================================== ================================================
.. module:: tensor.basic_opt .. module:: tensor.rewriting.basic
:platform: Unix, Windows :platform: Unix, Windows
:synopsis: Tensor Rewrites :synopsis: Tensor Rewrites
.. moduleauthor:: LISA, PyMC Developers, Aesara Developers .. moduleauthor:: LISA, PyMC Developers, Aesara Developers
.. automodule:: aesara.tensor.basic_opt .. automodule:: aesara.tensor.rewriting.basic
:members: :members:
...@@ -25,8 +25,8 @@ They are grouped into the following sections: ...@@ -25,8 +25,8 @@ They are grouped into the following sections:
elemwise elemwise
extra_ops extra_ops
io io
basic_opt
slinalg slinalg
nlinalg nlinalg
fft fft
math_opt math_opt
basic_opt
============================================================== ===================================================================
:mod:`tensor.math_opt` -- Tensor Rewrites for Math Operations :mod:`tensor.rewriting.math` -- Tensor Rewrites for Math Operations
============================================================== ===================================================================
.. module:: tensor.math_opt .. module:: tensor.rewriting.math
:platform: Unix, Windows :platform: Unix, Windows
:synopsis: Tensor Rewrites for Math Operations :synopsis: Tensor Rewrites for Math Operations
.. moduleauthor:: LISA, PyMC Developers, Aesara Developers .. moduleauthor:: LISA, PyMC Developers, Aesara Developers
.. automodule:: aesara.tensor.math_opt .. automodule:: aesara.tensor.rewriting.math
:members: :members:
...@@ -157,7 +157,7 @@ check_untyped_defs = False ...@@ -157,7 +157,7 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.tensor.basic_opt] [mypy-aesara.tensor.rewriting.basic]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
...@@ -193,7 +193,7 @@ check_untyped_defs = False ...@@ -193,7 +193,7 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.tensor.math_opt] [mypy-aesara.tensor.rewriting.math]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
...@@ -16,12 +16,12 @@ from aesara.graph.rewriting.utils import rewrite_graph ...@@ -16,12 +16,12 @@ from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.printing import debugprint from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer
from aesara.tensor.math import dot, exp from aesara.tensor.math import dot, exp
from aesara.tensor.math import round as at_round from aesara.tensor.math import round as at_round
from aesara.tensor.math import sigmoid from aesara.tensor.math import sigmoid
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.rewriting.basic import ShapeOptimizer
from aesara.tensor.shape import specify_shape from aesara.tensor.shape import specify_shape
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools from tests import unittest_tools
......
...@@ -23,8 +23,8 @@ from aesara.graph.rewriting.basic import ( ...@@ -23,8 +23,8 @@ from aesara.graph.rewriting.basic import (
pre_greedy_node_rewriter, pre_greedy_node_rewriter,
) )
from aesara.raise_op import assert_op from aesara.raise_op import assert_op
from aesara.tensor.basic_opt import constant_folding
from aesara.tensor.math import Dot, add, dot from aesara.tensor.math import Dot, add, dot
from aesara.tensor.rewriting.basic import constant_folding
from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import matrix, values_eq_approx_always_true from aesara.tensor.type import matrix, values_eq_approx_always_true
from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype
......
...@@ -15,7 +15,6 @@ from aesara.graph.basic import Constant, Variable, graph_inputs ...@@ -15,7 +15,6 @@ from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.random.basic import ( from aesara.tensor.random.basic import (
bernoulli, bernoulli,
beta, beta,
...@@ -56,6 +55,7 @@ from aesara.tensor.random.basic import ( ...@@ -56,6 +55,7 @@ from aesara.tensor.random.basic import (
wald, wald,
weibull, weibull,
) )
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.type import iscalar, scalar, tensor from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param from tests.unittest_tools import create_aesara_param
......
...@@ -37,19 +37,6 @@ from aesara.tensor.basic import ( ...@@ -37,19 +37,6 @@ from aesara.tensor.basic import (
second, second,
tile, tile,
) )
from aesara.tensor.basic_opt import (
ShapeFeature,
assert_op,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
local_useless_alloc,
local_useless_dimshuffle_in_reshape,
local_useless_elemwise,
local_useless_reshape,
register_specialize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from aesara.tensor.math import ( from aesara.tensor.math import (
...@@ -85,7 +72,20 @@ from aesara.tensor.math import round as at_round ...@@ -85,7 +72,20 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import local_lift_transpose_through_dot from aesara.tensor.rewriting.basic import (
ShapeFeature,
assert_op,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
local_useless_alloc,
local_useless_dimshuffle_in_reshape,
local_useless_elemwise,
local_useless_reshape,
register_specialize,
)
from aesara.tensor.rewriting.math import local_lift_transpose_through_dot
from aesara.tensor.shape import ( from aesara.tensor.shape import (
Reshape, Reshape,
Shape_i, Shape_i,
......
...@@ -30,7 +30,6 @@ from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph ...@@ -30,7 +30,6 @@ from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
...@@ -80,7 +79,8 @@ from aesara.tensor.math import round as at_round ...@@ -80,7 +79,8 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import ( from aesara.tensor.rewriting.basic import local_dimshuffle_lift
from aesara.tensor.rewriting.math import (
compute_mul, compute_mul,
is_1pexp, is_1pexp,
local_grad_log_erfc_neg, local_grad_log_erfc_neg,
......
...@@ -19,6 +19,10 @@ from aesara.tensor import inplace ...@@ -19,6 +19,10 @@ from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Dot, add, dot, exp, sqr from aesara.tensor.math import Dot, add, dot, exp, sqr
from aesara.tensor.rewriting.subtensor import (
local_replace_AdvancedSubtensor,
local_subtensor_shape_constant,
)
from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -30,10 +34,6 @@ from aesara.tensor.subtensor import ( ...@@ -30,10 +34,6 @@ from aesara.tensor.subtensor import (
inc_subtensor, inc_subtensor,
set_subtensor, set_subtensor,
) )
from aesara.tensor.subtensor_opt import (
local_replace_AdvancedSubtensor,
local_subtensor_shape_constant,
)
from aesara.tensor.type import ( from aesara.tensor.type import (
bmatrix, bmatrix,
col, col,
......
...@@ -13,7 +13,7 @@ from aesara.tensor.math import MaxAndArgmax ...@@ -13,7 +13,7 @@ from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import max as at_max from aesara.tensor.math import max as at_max
from aesara.tensor.math import max_and_argmax from aesara.tensor.math import max_and_argmax
from aesara.tensor.math import min as at_min from aesara.tensor.math import min as at_min
from aesara.tensor.opt_uncanonicalize import ( from aesara.tensor.rewriting.uncanonicalize import (
local_alloc_dimshuffle, local_alloc_dimshuffle,
local_dimshuffle_alloc, local_dimshuffle_alloc,
local_dimshuffle_subtensor, local_dimshuffle_subtensor,
......
...@@ -17,10 +17,10 @@ from aesara.link.basic import PerformLinker ...@@ -17,10 +17,10 @@ from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.tensor import as_tensor_variable from aesara.tensor import as_tensor_variable
from aesara.tensor.basic import second from aesara.tensor.basic import second
from aesara.tensor.basic_opt import ShapeError
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any from aesara.tensor.math import any as at_any
from aesara.tensor.rewriting.basic import ShapeError
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
bmatrix, bmatrix,
......
...@@ -11,8 +11,8 @@ from aesara.graph.type import Type ...@@ -11,8 +11,8 @@ from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import as_tensor_variable, get_vector_length, row from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.shape import ( from aesara.tensor.shape import (
Reshape, Reshape,
Shape_i, Shape_i,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论