提交 ced64bfc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move basic tensor rewriting code to aesara.tensor.rewriting

- `aesara.tensor.basic_opt` has been changed to `aesara.tensor.rewriting.basic` - `aesara.tensor.math_opt` has been changed to `aesara.tensor.rewriting.math` - `aesara.tensor.subtensor_opt` has been changed to `aesara.tensor.rewriting.subtensor` - `aesara.tensor.opt_uncanonicalize` has been changed to `aesara.tensor.rewriting.uncanonicalize` The tests associated with each module have been updated accordingly.
上级 57acc845
......@@ -72,9 +72,9 @@ jobs:
install-numba: [1]
part:
- "tests --ignore=tests/tensor --ignore=tests/sparse --ignore=tests/tensor/nnet"
- "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_basic_opt.py --ignore=tests/tensor/test_math_opt.py --ignore=tests/tensor/nnet"
- "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/rewriting/test_basic.py --ignore=tests/tensor/rewriting/test_math.py --ignore=tests/tensor/nnet"
- "tests/tensor/test_basic.py tests/tensor/test_math.py tests/tensor/test_math_scipy.py tests/tensor/test_inplace.py"
- "tests/tensor/test_elemwise.py tests/tensor/test_basic_opt.py tests/tensor/test_math_opt.py"
- "tests/tensor/test_elemwise.py tests/tensor/rewriting/test_basic.py tests/tensor/rewriting/test_math.py"
- "tests/tensor/nnet --ignore-glob='*/test_abstract_conv.py'"
- "tests/tensor/nnet/test_abstract_conv.py"
include:
......
......@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.rewriting.basic import ShapeFeature
def infer_shape(outs, inputs, input_shapes):
......
......@@ -2,17 +2,17 @@ import logging
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor import basic as at
from aesara.tensor.basic_opt import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import Dot, Prod, dot, log
from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod
from aesara.tensor.nlinalg import Det, MatrixInverse, trace
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.slinalg import Cholesky, Solve, cholesky, solve
......
......@@ -41,11 +41,12 @@ from aesara.scan.utils import (
safe_new,
scan_can_remove_outs,
)
from aesara.tensor import basic_opt, math_opt
from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, dot, maximum, minimum
from aesara.tensor.rewriting import basic as basic_opt
from aesara.tensor.rewriting import math as math_opt
from aesara.tensor.shape import shape
from aesara.tensor.subtensor import (
IncSubtensor,
......
......@@ -24,8 +24,8 @@ from aesara.sparse.basic import (
)
from aesara.tensor import blas
from aesara.tensor.basic import as_tensor_variable, cast
from aesara.tensor.basic_opt import register_canonicalize, register_specialize
from aesara.tensor.math import mul, neg, sub
from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
from aesara.tensor.shape import shape, specify_shape
from aesara.tensor.type import TensorType, tensor
......
......@@ -103,15 +103,13 @@ from aesara.gradient import consider_constant, grad, hessian, jacobian # noqa
# adds shared-variable constructors
from aesara.tensor import sharedvar # noqa
from aesara.tensor import ( # noqa
basic_opt,
blas,
blas_c,
blas_scipy,
nnet,
opt_uncanonicalize,
subtensor_opt,
xlogx,
)
import aesara.tensor.rewriting
# isort: off
......
......@@ -1316,7 +1316,7 @@ def infer_broadcastable(shape):
`shape` will be validated and constant folded in order to determine
which dimensions are broadcastable (i.e. equal to ``1``).
"""
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
def check_type(s):
if s.type.dtype in integer_dtypes:
......
......@@ -159,11 +159,11 @@ from aesara.link.c.params_type import ParamsType
from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at
from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.rewriting.basic import local_dimshuffle_lift
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import (
DenseTensorType,
......
......@@ -806,7 +806,7 @@ class Elemwise(OpenMPOp):
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError
from aesara.tensor.exceptions import ShapeError
raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
......
......@@ -24,11 +24,6 @@ from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp
from aesara.tensor import basic as at
from aesara.tensor.basic import ARange, as_tensor_variable
from aesara.tensor.basic_opt import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import Unique
......@@ -50,8 +45,13 @@ from aesara.tensor.math import (
)
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.math_opt import local_mul_canonizer
from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from aesara.tensor.type import (
......
......@@ -8,10 +8,10 @@ from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div
from aesara.tensor import basic as at
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import mean, prod, reciprocal, sqrt
from aesara.tensor.math import sum as at_sum
from aesara.tensor.rewriting.basic import register_specialize_device
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import TensorType
......
......@@ -8,9 +8,9 @@ from aesara.graph.basic import Apply
from aesara.graph.rewriting.basic import node_rewriter
from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.blas import batched_dot
from aesara.tensor.extra_ops import cpu_contiguous
from aesara.tensor.rewriting.basic import register_canonicalize
from aesara.tensor.type import ftensor3, fvector
......
......@@ -13,7 +13,6 @@ from aesara.graph.rewriting.basic import (
in2out,
node_rewriter,
)
from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.nnet.abstract_conv import (
AbstractConv2d,
AbstractConv2d_gradInputs,
......@@ -34,6 +33,7 @@ from aesara.tensor.nnet.blocksparse import (
from aesara.tensor.nnet.conv import ConvOp, conv2d
from aesara.tensor.nnet.corr import CorrMM, CorrMM_gradInputs, CorrMM_gradWeights
from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGradWeights
from aesara.tensor.rewriting.basic import register_specialize_device
from aesara.tensor.type import TensorType
......
import aesara.tensor.rewriting.basic
import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize
......@@ -103,7 +103,7 @@ from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter
_logger = logging.getLogger("aesara.tensor.basic_opt")
_logger = logging.getLogger("aesara.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter())
......
......@@ -35,19 +35,6 @@ from aesara.tensor.basic import (
switch,
zeros_like,
)
from aesara.tensor.basic_opt import (
FusionOptimizer,
broadcast_like,
encompasses_broadcastable,
fuse_seqopt,
local_fill_sink,
register_canonicalize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
)
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import (
......@@ -84,6 +71,19 @@ from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import true_div
from aesara.tensor.rewriting.basic import (
FusionOptimizer,
broadcast_like,
encompasses_broadcastable,
fuse_seqopt,
local_fill_sink,
register_canonicalize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
)
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
......@@ -648,7 +648,7 @@ class AlgebraicCanonizer(NodeRewriter):
Examples
--------
>>> import aesara.tensor as at
>>> from aesara.tensor.math_opt import AlgebraicCanonizer
>>> from aesara.tensor.rewriting.math import AlgebraicCanonizer
>>> add_canonizer = AlgebraicCanonizer(add, sub, neg, \\
... lambda n, d: sum(n) - sum(d))
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
......
......@@ -28,11 +28,6 @@ from aesara.tensor.basic import (
get_scalar_constant_value,
switch,
)
from aesara.tensor.basic_opt import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add
......@@ -50,6 +45,11 @@ from aesara.tensor.math import (
minimum,
or_,
)
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from aesara.tensor.shape import (
Shape,
SpecifyShape,
......@@ -1390,8 +1390,11 @@ def local_setsubtensor_of_constants(fgraph, node):
def local_adv_sub1_adv_inc_sub1(fgraph, node):
"""Rewrite graphs like ``AdvancedSubtensor1(AdvancedSetSubtensor1(...), ...)``.
.. code::
AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y
Notes
-----
This rewrite adds an `AssertOp`; otherwise, it would remove shape and index
......
......@@ -31,21 +31,16 @@ supposed to be canonical.
"""
import logging
from aesara import scalar as aes
from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.tensor.basic import Alloc, alloc, constant
from aesara.tensor.basic_opt import register_uncanonicalize
from aesara.tensor.elemwise import CAReduce, DimShuffle
from aesara.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
from aesara.tensor.rewriting.basic import register_uncanonicalize
from aesara.tensor.shape import Reshape, reshape
from aesara.tensor.subtensor import Subtensor
_logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
@register_uncanonicalize
@node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node):
......
......@@ -63,7 +63,7 @@ def shape_of_variables(fgraph, input_shapes):
"""
if not hasattr(fgraph, "shape_feature"):
fgraph.attach_feature(aesara.tensor.basic_opt.ShapeFeature())
fgraph.attach_feature(aesara.tensor.rewriting.basic.ShapeFeature())
input_dims = [
dimension
......
===================================================================
:mod:`tensor.basic_opt` -- Tensor Rewrites
===================================================================
================================================
:mod:`tensor.rewriting.basic` -- Tensor Rewrites
================================================
.. module:: tensor.basic_opt
.. module:: tensor.rewriting.basic
:platform: Unix, Windows
:synopsis: Tensor Rewrites
.. moduleauthor:: LISA, PyMC Developers, Aesara Developers
.. automodule:: aesara.tensor.basic_opt
.. automodule:: aesara.tensor.rewriting.basic
:members:
......@@ -25,8 +25,8 @@ They are grouped into the following sections:
elemwise
extra_ops
io
basic_opt
slinalg
nlinalg
fft
math_opt
basic_opt
==============================================================
:mod:`tensor.math_opt` -- Tensor Rewrites for Math Operations
==============================================================
===================================================================
:mod:`tensor.rewriting.math` -- Tensor Rewrites for Math Operations
===================================================================
.. module:: tensor.math_opt
.. module:: tensor.rewriting.math
:platform: Unix, Windows
:synopsis: Tensor Rewrites for Math Operations
.. moduleauthor:: LISA, PyMC Developers, Aesara Developers
.. automodule:: aesara.tensor.math_opt
.. automodule:: aesara.tensor.rewriting.math
:members:
......@@ -157,7 +157,7 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.basic_opt]
[mypy-aesara.tensor.rewriting.basic]
ignore_errors = True
check_untyped_defs = False
......@@ -193,7 +193,7 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.math_opt]
[mypy-aesara.tensor.rewriting.math]
ignore_errors = True
check_untyped_defs = False
......
......@@ -16,12 +16,12 @@ from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.utils import MissingInputError
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer
from aesara.tensor.math import dot, exp
from aesara.tensor.math import round as at_round
from aesara.tensor.math import sigmoid
from aesara.tensor.math import sum as at_sum
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.rewriting.basic import ShapeOptimizer
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools
......
......@@ -23,8 +23,8 @@ from aesara.graph.rewriting.basic import (
pre_greedy_node_rewriter,
)
from aesara.raise_op import assert_op
from aesara.tensor.basic_opt import constant_folding
from aesara.tensor.math import Dot, add, dot
from aesara.tensor.rewriting.basic import constant_folding
from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import matrix, values_eq_approx_always_true
from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype
......
......@@ -15,7 +15,6 @@ from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.random.basic import (
bernoulli,
beta,
......@@ -56,6 +55,7 @@ from aesara.tensor.random.basic import (
wald,
weibull,
)
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param
......
......@@ -37,19 +37,6 @@ from aesara.tensor.basic import (
second,
tile,
)
from aesara.tensor.basic_opt import (
ShapeFeature,
assert_op,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
local_useless_alloc,
local_useless_dimshuffle_in_reshape,
local_useless_elemwise,
local_useless_reshape,
register_specialize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from aesara.tensor.math import (
......@@ -85,7 +72,20 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import local_lift_transpose_through_dot
from aesara.tensor.rewriting.basic import (
ShapeFeature,
assert_op,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
local_useless_alloc,
local_useless_dimshuffle_in_reshape,
local_useless_elemwise,
local_useless_reshape,
register_specialize,
)
from aesara.tensor.rewriting.math import local_lift_transpose_through_dot
from aesara.tensor.shape import (
Reshape,
Shape_i,
......
......@@ -30,7 +30,6 @@ from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -80,7 +79,8 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import (
from aesara.tensor.rewriting.basic import local_dimshuffle_lift
from aesara.tensor.rewriting.math import (
compute_mul,
is_1pexp,
local_grad_log_erfc_neg,
......
......@@ -19,6 +19,10 @@ from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Dot, add, dot, exp, sqr
from aesara.tensor.rewriting.subtensor import (
local_replace_AdvancedSubtensor,
local_subtensor_shape_constant,
)
from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -30,10 +34,6 @@ from aesara.tensor.subtensor import (
inc_subtensor,
set_subtensor,
)
from aesara.tensor.subtensor_opt import (
local_replace_AdvancedSubtensor,
local_subtensor_shape_constant,
)
from aesara.tensor.type import (
bmatrix,
col,
......
......@@ -13,7 +13,7 @@ from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import max as at_max
from aesara.tensor.math import max_and_argmax
from aesara.tensor.math import min as at_min
from aesara.tensor.opt_uncanonicalize import (
from aesara.tensor.rewriting.uncanonicalize import (
local_alloc_dimshuffle,
local_dimshuffle_alloc,
local_dimshuffle_subtensor,
......
......@@ -17,10 +17,10 @@ from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.tensor import as_tensor_variable
from aesara.tensor.basic import second
from aesara.tensor.basic_opt import ShapeError
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any
from aesara.tensor.rewriting.basic import ShapeError
from aesara.tensor.type import (
TensorType,
bmatrix,
......
......@@ -11,8 +11,8 @@ from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.shape import (
Reshape,
Shape_i,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论