提交 63f52536 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Split aesara.tensor.rewriting.basic rewrites by their aesara.tensor modules

上级 9704ed42
...@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType ...@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.tensor.rewriting.basic import ShapeFeature from aesara.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
......
...@@ -45,8 +45,9 @@ from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value ...@@ -45,8 +45,9 @@ from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, dot, maximum, minimum from aesara.tensor.math import Dot, dot, maximum, minimum
from aesara.tensor.rewriting import basic as basic_opt from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
from aesara.tensor.rewriting import math as math_opt from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from aesara.tensor.shape import shape from aesara.tensor.shape import shape
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
IncSubtensor, IncSubtensor,
...@@ -60,11 +61,11 @@ from aesara.tensor.var import TensorConstant, get_unique_value ...@@ -60,11 +61,11 @@ from aesara.tensor.var import TensorConstant, get_unique_value
list_opt_slice = [ list_opt_slice = [
math_opt.local_abs_merge, local_abs_merge,
math_opt.local_mul_switch_sink, local_mul_switch_sink,
basic_opt.local_upcast_elemwise_constant_inputs, local_upcast_elemwise_constant_inputs,
basic_opt.local_useless_switch, local_useless_switch,
basic_opt.constant_folding, constant_folding,
] ]
...@@ -2432,7 +2433,7 @@ scan_seqopt1.register( ...@@ -2432,7 +2433,7 @@ scan_seqopt1.register(
scan_eqopt2.register( scan_eqopt2.register(
"constant_folding_for_scan2", "constant_folding_for_scan2",
in2out(basic_opt.constant_folding, ignore_newtrees=True), in2out(constant_folding, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
) )
......
...@@ -29,7 +29,7 @@ from aesara.graph.type import Type ...@@ -29,7 +29,7 @@ from aesara.graph.type import Type
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint from aesara.printing import Printer, min_informative_str, pprint, set_precedence
from aesara.raise_op import CheckAndRaise, assert_op from aesara.raise_op import CheckAndRaise, assert_op
from aesara.scalar import int32 from aesara.scalar import int32
from aesara.scalar.basic import ScalarConstant, ScalarVariable from aesara.scalar.basic import ScalarConstant, ScalarVariable
...@@ -1335,7 +1335,8 @@ def infer_broadcastable(shape): ...@@ -1335,7 +1335,8 @@ def infer_broadcastable(shape):
`shape` will be validated and constant folded in order to determine `shape` will be validated and constant folded in order to determine
which dimensions are broadcastable (i.e. equal to ``1``). which dimensions are broadcastable (i.e. equal to ``1``).
""" """
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding from aesara.tensor.rewriting.basic import topo_constant_folding
from aesara.tensor.rewriting.shape import ShapeFeature
def check_type(s): def check_type(s):
if s.type.dtype in integer_dtypes: if s.type.dtype in integer_dtypes:
...@@ -1709,6 +1710,21 @@ class MakeVector(COp): ...@@ -1709,6 +1710,21 @@ class MakeVector(COp):
make_vector = MakeVector() make_vector = MakeVector()
class MakeVectorPrinter(Printer):
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector):
with set_precedence(pstate):
s = [pstate.pprinter.process(inp) for inp in r.owner.inputs]
return f"[{', '.join(s)}]"
else:
raise TypeError("Can only print make_vector.")
pprint.assign(MakeVector, MakeVectorPrinter())
@_get_vector_length.register(MakeVector) @_get_vector_length.register(MakeVector)
def _get_vector_length_MakeVector(op, var): def _get_vector_length_MakeVector(op, var):
return len(var.owner.inputs) return len(var.owner.inputs)
......
...@@ -8,3 +8,6 @@ warnings.warn( ...@@ -8,3 +8,6 @@ warnings.warn(
) )
from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403 from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403
from aesara.tensor.rewriting.elemwise import * # noqa: F401 E402 F403
from aesara.tensor.rewriting.extra_ops import * # noqa: F401 E402 F403
from aesara.tensor.rewriting.shape import * # noqa: F401 E402 F403
...@@ -163,7 +163,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version ...@@ -163,7 +163,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.rewriting.basic import local_dimshuffle_lift from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
from aesara.tensor.shape import specify_broadcastable from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import ( from aesara.tensor.type import (
DenseTensorType, DenseTensorType,
......
import aesara.tensor.rewriting.basic import aesara.tensor.rewriting.basic
import aesara.tensor.rewriting.elemwise
import aesara.tensor.rewriting.extra_ops
import aesara.tensor.rewriting.math import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.shape
import aesara.tensor.rewriting.subtensor import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize import aesara.tensor.rewriting.uncanonicalize
差异被折叠。
import aesara.scalar.basic as aes
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.basic import Alloc, as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique
from aesara.tensor.rewriting.basic import register_canonicalize, register_useless
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_scalar(fgraph, node):
"""Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
if not isinstance(node.op, Unique):
return False
if node.op.return_index or node.op.return_inverse or node.op.return_counts:
return False
uniqued_var = node.inputs[0]
if uniqued_var.ndim != 0:
return False
old_out = node.outputs[0]
res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype)
return [res]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node):
"""Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.
This isn't really so much a lift as a "reduction/consumption".
"""
if not isinstance(node.op, Unique):
return False
if (
node.op.return_index
or node.op.return_inverse
or node.op.return_counts
or node.op.axis is not None
):
return False
alloc_var = node.inputs[0]
if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)):
return False
alloced_var, *alloc_shape = alloc_var.owner.inputs
new_unique, *_ = node.op.make_node(alloced_var).outputs
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
"""Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
This isn't really so much a lift as a "reduction/consumption".
"""
if not isinstance(node.op, Unique):
return False
if (
node.op.return_index
or node.op.return_inverse
or node.op.return_counts
or node.op.axis is not None
):
return False
bcast_var = node.inputs[0]
if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
return False
bcasted_var, *bcast_shape = bcast_var.owner.inputs
new_unique, *_ = node.op.make_node(bcasted_var).outputs
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node):
"""Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.
This isn't really so much a lift as a "reduction/consumption".
"""
if not isinstance(node.op, Unique):
return False
if (
node.op.return_index
or node.op.return_inverse
or node.op.return_counts
or node.op.axis is not None
):
return False
repeat_var = node.inputs[0]
if not (repeat_var.owner and isinstance(repeat_var.owner.op, Repeat)):
return False
repeated_var, *repeat_shape = repeat_var.owner.inputs
new_unique, *_ = node.op.make_node(repeated_var).outputs
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_second(fgraph, node):
"""Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``.
This isn't really so much a lift as a "reduction/consumption".
"""
if not isinstance(node.op, Unique):
return False
if (
node.op.return_index
or node.op.return_inverse
or node.op.return_counts
or node.op.axis is not None
):
return False
second_var = node.inputs[0]
if not (
second_var.owner
and isinstance(second_var.owner.op, Elemwise)
and isinstance(second_var.owner.op.scalar_op, aes.Second)
):
return False
shape_var, seconded_var = second_var.owner.inputs
new_unique, *_ = node.op.make_node(seconded_var).outputs
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
bcast_shape = node.inputs[1:]
if not bcast_shape:
bcasted_var = node.inputs[0]
# If this isn't true, the graph is invalid
assert bcasted_var.ndim == 0
return [bcasted_var]
...@@ -72,10 +72,8 @@ from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sq ...@@ -72,10 +72,8 @@ from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sq
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import true_div from aesara.tensor.math import true_div
from aesara.tensor.rewriting.basic import ( from aesara.tensor.rewriting.basic import (
FusionOptimizer,
broadcast_like, broadcast_like,
encompasses_broadcastable, encompasses_broadcastable,
fuse_seqopt,
local_fill_sink, local_fill_sink,
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
...@@ -84,6 +82,7 @@ from aesara.tensor.rewriting.basic import ( ...@@ -84,6 +82,7 @@ from aesara.tensor.rewriting.basic import (
register_uncanonicalize, register_uncanonicalize,
register_useless, register_useless,
) )
from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from aesara.tensor.shape import Shape, Shape_i from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
......
差异被折叠。
...@@ -63,7 +63,9 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -63,7 +63,9 @@ def shape_of_variables(fgraph, input_shapes):
""" """
if not hasattr(fgraph, "shape_feature"): if not hasattr(fgraph, "shape_feature"):
fgraph.attach_feature(aesara.tensor.rewriting.basic.ShapeFeature()) from aesara.tensor.rewriting.shape import ShapeFeature
fgraph.attach_feature(ShapeFeature())
input_dims = [ input_dims = [
dimension dimension
......
...@@ -21,7 +21,7 @@ from aesara.tensor.math import round as at_round ...@@ -21,7 +21,7 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sigmoid from aesara.tensor.math import sigmoid
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.rewriting.basic import ShapeOptimizer from aesara.tensor.rewriting.shape import ShapeOptimizer
from aesara.tensor.shape import specify_shape from aesara.tensor.shape import specify_shape
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools from tests import unittest_tools
......
...@@ -55,7 +55,7 @@ from aesara.tensor.random.basic import ( ...@@ -55,7 +55,7 @@ from aesara.tensor.random.basic import (
wald, wald,
weibull, weibull,
) )
from aesara.tensor.rewriting.basic import ShapeFeature from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.type import iscalar, scalar, tensor from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param from tests.unittest_tools import create_aesara_param
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论