Commit 63f52536 authored by Brandon T. Willard, committed by Brandon T. Willard

Split aesara.tensor.rewriting.basic rewrites by their aesara.tensor modules

Parent 9704ed42
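In short: rewrites that previously lived in `aesara.tensor.rewriting.basic` now sit in modules mirroring the `aesara.tensor` modules they target. A minimal sketch of the resulting import migration, using names taken from the hunks below:

```python
# Before: these were all imported from one module, e.g.
#   from aesara.tensor.rewriting.basic import ShapeFeature, FusionOptimizer, local_dimshuffle_lift

# After: each import follows the aesara.tensor module the rewrite targets.
from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
```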
@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError
-from aesara.tensor.rewriting.basic import ShapeFeature
+from aesara.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes):
......
@@ -45,8 +45,9 @@ from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, dot, maximum, minimum
-from aesara.tensor.rewriting import basic as basic_opt
-from aesara.tensor.rewriting import math as math_opt
+from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
+from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
+from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from aesara.tensor.shape import shape
from aesara.tensor.subtensor import (
    IncSubtensor,
@@ -60,11 +61,11 @@ from aesara.tensor.var import TensorConstant, get_unique_value
list_opt_slice = [
-    math_opt.local_abs_merge,
-    math_opt.local_mul_switch_sink,
-    basic_opt.local_upcast_elemwise_constant_inputs,
-    basic_opt.local_useless_switch,
-    basic_opt.constant_folding,
+    local_abs_merge,
+    local_mul_switch_sink,
+    local_upcast_elemwise_constant_inputs,
+    local_useless_switch,
+    constant_folding,
]
@@ -2432,7 +2433,7 @@ scan_seqopt1.register(
scan_eqopt2.register(
"constant_folding_for_scan2",
in2out(basic_opt.constant_folding, ignore_newtrees=True),
in2out(constant_folding, ignore_newtrees=True),
"fast_run",
"scan",
)
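For context, `in2out` wraps a node rewriter into a graph rewriter that traverses the graph from inputs to outputs; a minimal sketch of the registration pattern used above (the variable name is illustrative):

```python
from aesara.graph.rewriting.basic import in2out
from aesara.tensor.rewriting.basic import constant_folding

# ignore_newtrees=True: nodes introduced by earlier replacements are not revisited.
constant_folding_pass = in2out(constant_folding, ignore_newtrees=True)
```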
......
@@ -29,7 +29,7 @@ from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
-from aesara.printing import min_informative_str, pprint
+from aesara.printing import Printer, min_informative_str, pprint, set_precedence
from aesara.raise_op import CheckAndRaise, assert_op
from aesara.scalar import int32
from aesara.scalar.basic import ScalarConstant, ScalarVariable
@@ -1335,7 +1335,8 @@ def infer_broadcastable(shape):
    `shape` will be validated and constant folded in order to determine
    which dimensions are broadcastable (i.e. equal to ``1``).
    """
-    from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
+    from aesara.tensor.rewriting.basic import topo_constant_folding
+    from aesara.tensor.rewriting.shape import ShapeFeature

    def check_type(s):
        if s.type.dtype in integer_dtypes:
@@ -1709,6 +1710,21 @@ class MakeVector(COp):
make_vector = MakeVector()
+class MakeVectorPrinter(Printer):
+    def process(self, r, pstate):
+        if r.owner is None:
+            raise TypeError("Can only print make_vector.")
+        elif isinstance(r.owner.op, MakeVector):
+            with set_precedence(pstate):
+                s = [pstate.pprinter.process(inp) for inp in r.owner.inputs]
+            return f"[{', '.join(s)}]"
+        else:
+            raise TypeError("Can only print make_vector.")
+
+pprint.assign(MakeVector, MakeVectorPrinter())
@_get_vector_length.register(MakeVector)
def _get_vector_length_MakeVector(op, var):
    return len(var.owner.inputs)
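A hedged usage sketch for the new `MakeVectorPrinter` (variable names are illustrative, and the exact output may vary by version):

```python
import aesara.tensor as at
from aesara.printing import pprint
from aesara.tensor.basic import make_vector

# make_vector defaults to int64, so int64 scalars are used here.
x, y = at.lscalars("x", "y")
v = make_vector(x, y)

print(pprint(v))  # expected to render roughly as: [x, y]
```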
......
@@ -8,3 +8,6 @@ warnings.warn(
)
from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403
+from aesara.tensor.rewriting.elemwise import * # noqa: F401 E402 F403
+from aesara.tensor.rewriting.extra_ops import * # noqa: F401 E402 F403
+from aesara.tensor.rewriting.shape import * # noqa: F401 E402 F403
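Assuming this hunk is the deprecated shim module (the old `aesara.tensor.basic_opt` path is an assumption, not shown in this diff), the added star imports keep old import sites working while the module-level `warnings.warn(...)` above flags the deprecation:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Hypothetical old-path import; it should resolve via the re-exports above.
    from aesara.tensor.basic_opt import local_dimshuffle_lift  # noqa: F401

print([str(w.message) for w in caught])  # expect a deprecation notice
```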
@@ -163,7 +163,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub
-from aesara.tensor.rewriting.basic import local_dimshuffle_lift
+from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import (
    DenseTensorType,
......
import aesara.tensor.rewriting.basic
+import aesara.tensor.rewriting.elemwise
+import aesara.tensor.rewriting.extra_ops
import aesara.tensor.rewriting.math
+import aesara.tensor.rewriting.shape
import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize
Diff collapsed.
import aesara.scalar.basic as aes
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.basic import Alloc, as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique
from aesara.tensor.rewriting.basic import register_canonicalize, register_useless


@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_scalar(fgraph, node):
    """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
    if not isinstance(node.op, Unique):
        return False
    if node.op.return_index or node.op.return_inverse or node.op.return_counts:
        return False
    uniqued_var = node.inputs[0]
    if uniqued_var.ndim != 0:
        return False
    old_out = node.outputs[0]
    res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype)
    return [res]
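A minimal sketch of this rewrite in action, assuming the default `aesara.function` pipeline (where `register_canonicalize`/`register_useless` rewrites run):

```python
import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import Unique

x = at.scalar("x")
y = Unique()(x)  # unique of a 0-d input is just that input

fn = aesara.function([x], y)
aesara.dprint(fn)  # expected: no Unique node survives in the compiled graph
```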


@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False
    alloc_var = node.inputs[0]
    if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)):
        return False
    alloced_var, *alloc_shape = alloc_var.owner.inputs
    new_unique, *_ = node.op.make_node(alloced_var).outputs
    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]


@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False
    bcast_var = node.inputs[0]
    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
        return False
    bcasted_var, *bcast_shape = bcast_var.owner.inputs
    new_unique, *_ = node.op.make_node(bcasted_var).outputs
    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]


@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False
    repeat_var = node.inputs[0]
    if not (repeat_var.owner and isinstance(repeat_var.owner.op, Repeat)):
        return False
    repeated_var, *repeat_shape = repeat_var.owner.inputs
    new_unique, *_ = node.op.make_node(repeated_var).outputs
    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]


@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(x, y), axis=None)`` to ``unique(y, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False
    second_var = node.inputs[0]
    if not (
        second_var.owner
        and isinstance(second_var.owner.op, Elemwise)
        and isinstance(second_var.owner.op.scalar_op, aes.Second)
    ):
        return False
    shape_var, seconded_var = second_var.owner.inputs
    new_unique, *_ = node.op.make_node(seconded_var).outputs
    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
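The four lift rewrites above share one idea: `Alloc`, `BroadcastTo`, `Repeat`, and `second` only replicate existing values, so with `axis=None` they cannot change the set of unique values. A hedged sketch of the `Alloc` case:

```python
import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import Unique

x = at.vector("x")
# Tiling x into two identical rows adds no new values, so
# unique(alloc(x, ...), axis=None) becomes unique(x, axis=None).
y = Unique()(at.alloc(x, 2, x.shape[0]))

fn = aesara.function([x], y)
aesara.dprint(fn)  # expected: the Alloc is rewritten away
```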


@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
    bcast_shape = node.inputs[1:]

    if not bcast_shape:
        bcasted_var = node.inputs[0]
        # If this isn't true, the graph is invalid
        assert bcasted_var.ndim == 0
        return [bcasted_var]
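Likewise, broadcasting a 0-d input to an empty target shape is an identity; a minimal sketch, assuming `broadcast_to` accepts an empty shape tuple:

```python
import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_to

x = at.scalar("x")
y = broadcast_to(x, ())  # a no-op broadcast of a scalar

fn = aesara.function([x], y)
aesara.dprint(fn)  # expected: the BroadcastTo node is rewritten away
```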
@@ -72,10 +72,8 @@ from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sq
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import true_div
from aesara.tensor.rewriting.basic import (
-    FusionOptimizer,
    broadcast_like,
    encompasses_broadcastable,
-    fuse_seqopt,
    local_fill_sink,
    register_canonicalize,
    register_specialize,
@@ -84,6 +82,7 @@ from aesara.tensor.rewriting.basic import (
    register_uncanonicalize,
    register_useless,
)
+from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
......
Diff collapsed. (6 files)