提交 63f52536 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Split aesara.tensor.rewriting.basic rewrites by their aesara.tensor modules

上级 9704ed42
......@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes):
......
......@@ -45,8 +45,9 @@ from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, dot, maximum, minimum
from aesara.tensor.rewriting import basic as basic_opt
from aesara.tensor.rewriting import math as math_opt
from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch
from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from aesara.tensor.shape import shape
from aesara.tensor.subtensor import (
IncSubtensor,
......@@ -60,11 +61,11 @@ from aesara.tensor.var import TensorConstant, get_unique_value
list_opt_slice = [
math_opt.local_abs_merge,
math_opt.local_mul_switch_sink,
basic_opt.local_upcast_elemwise_constant_inputs,
basic_opt.local_useless_switch,
basic_opt.constant_folding,
local_abs_merge,
local_mul_switch_sink,
local_upcast_elemwise_constant_inputs,
local_useless_switch,
constant_folding,
]
......@@ -2432,7 +2433,7 @@ scan_seqopt1.register(
scan_eqopt2.register(
"constant_folding_for_scan2",
in2out(basic_opt.constant_folding, ignore_newtrees=True),
in2out(constant_folding, ignore_newtrees=True),
"fast_run",
"scan",
)
......
......@@ -29,7 +29,7 @@ from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint
from aesara.printing import Printer, min_informative_str, pprint, set_precedence
from aesara.raise_op import CheckAndRaise, assert_op
from aesara.scalar import int32
from aesara.scalar.basic import ScalarConstant, ScalarVariable
......@@ -1335,7 +1335,8 @@ def infer_broadcastable(shape):
`shape` will be validated and constant folded in order to determine
which dimensions are broadcastable (i.e. equal to ``1``).
"""
from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding
from aesara.tensor.rewriting.basic import topo_constant_folding
from aesara.tensor.rewriting.shape import ShapeFeature
def check_type(s):
if s.type.dtype in integer_dtypes:
......@@ -1709,6 +1710,21 @@ class MakeVector(COp):
make_vector = MakeVector()
class MakeVectorPrinter(Printer):
    """Pretty-printer that renders a `MakeVector` output as ``[e1, e2, ...]``."""

    def process(self, r, pstate):
        # Only the output of a `MakeVector` node can be printed here.
        if r.owner is None or not isinstance(r.owner.op, MakeVector):
            raise TypeError("Can only print make_vector.")
        with set_precedence(pstate):
            printed = [pstate.pprinter.process(elem) for elem in r.owner.inputs]
        return f"[{', '.join(printed)}]"
# Register the custom printer so `pprint` renders `MakeVector` nodes as lists.
pprint.assign(MakeVector, MakeVectorPrinter())
@_get_vector_length.register(MakeVector)
def _get_vector_length_MakeVector(op, var):
    # A `MakeVector` produces one element per input, so the static length of
    # the resulting vector is exactly the number of inputs of its node.
    return len(var.owner.inputs)
......
......@@ -8,3 +8,6 @@ warnings.warn(
)
from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403
from aesara.tensor.rewriting.elemwise import * # noqa: F401 E402 F403
from aesara.tensor.rewriting.extra_ops import * # noqa: F401 E402 F403
from aesara.tensor.rewriting.shape import * # noqa: F401 E402 F403
......@@ -163,7 +163,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.rewriting.basic import local_dimshuffle_lift
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import (
DenseTensorType,
......
import aesara.tensor.rewriting.basic
import aesara.tensor.rewriting.elemwise
import aesara.tensor.rewriting.extra_ops
import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.shape
import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize
""" Tensor optimizations addressing the ops in basic.py."""
import logging
import sys
import time
import traceback
from collections import defaultdict
from io import StringIO
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import aesara
import aesara.scalar.basic as aes
from aesara import compile
from aesara.compile.ops import ViewOp
from aesara.configdefaults import config
from aesara.graph.basic import (
Apply,
Constant,
Variable,
ancestors,
equal_computations,
io_toposort,
)
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.basic import Constant, Variable
from aesara.graph.rewriting.basic import (
GraphRewriter,
NodeRewriter,
RemovalNodeRewriter,
Rewriter,
check_chain,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.graph.rewriting.db import RewriteDatabase, SequenceDB
from aesara.graph.utils import (
InconsistencyError,
MethodNotDefined,
TestValueError,
get_variable_trace_string,
)
from aesara.printing import Printer, pprint, set_precedence
from aesara.graph.rewriting.db import RewriteDatabase
from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.tensor.basic import (
Alloc,
......@@ -56,53 +29,31 @@ from aesara.tensor.basic import (
alloc,
as_tensor_variable,
cast,
constant,
extract_constant,
fill,
get_scalar_constant_value,
join,
ones_like,
stack,
switch,
tensor_copy,
zeros,
zeros_like,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_shape,
broadcast_to,
)
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import broadcast_shape, broadcast_to
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
shape_i,
shape_padleft,
specify_shape,
unbroadcast,
)
from aesara.tensor.shape import Shape_i
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import (
DenseTensorType,
TensorType,
discrete_dtypes,
integer_dtypes,
)
from aesara.tensor.type_other import NoneConst
from aesara.tensor.type import DenseTensorType, TensorType
from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter
if TYPE_CHECKING:
from aesara.tensor.rewriting.shape import ShapeFeature
_logger = logging.getLogger("aesara.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter())
......@@ -164,320 +115,6 @@ def broadcast_like(value, template, fgraph, dtype=None):
return rval
class InplaceElemwiseOptimizer(GraphRewriter):
    r"""
    This is parameterized so that it works for `Elemwise` `Op`\s.
    """

    def __init__(self, OP):
        # The `Op` class (e.g. `Elemwise`) whose nodes this rewriter will try
        # to replace with in-place variants.
        self.op = OP

    def add_requirements(self, fgraph):
        # In-place replacement needs destroy-map tracking on the graph.
        from aesara.graph.destroyhandler import DestroyHandler

        fgraph.attach_feature(DestroyHandler())

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        # Pretty-print the profile dict produced by `apply` (counts of
        # replacements, validations, inconsistencies, and a per-ndim tally).
        blanc = " " * level
        print(blanc, cls.__name__, prof["opt"].op, file=stream)
        for k in [
            "node_before",
            "nb_call_replace",
            "nb_call_validate",
            "nb_inconsistent",
        ]:
            print(blanc, k, prof[k], file=stream)
        ndim = prof["ndim"]
        if ndim:
            print(blanc, "ndim", "nb", file=stream)
            for n in sorted(ndim.keys()):
                print(blanc, n, ndim[n], file=stream)

    def apply(self, fgraph):
        r"""
        Attempts to replace all `Elemwise`\s by versions of them that operate
        inplace. It operates greedily: for each `Elemwise` that is encountered,
        for each output, it tries each input to see if it can operate inplace
        on that input. If so, it makes the change and goes to the next output
        or `Elemwise`.

        Examples
        --------

            x + y + z -> x += y += z
            (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

        """
        # We should not validate too often as this takes too much time to
        # execute!
        # It is the _dfs_toposort() fct in aesara/graph/destroyhandler.py
        # that takes so much time.
        # Should we try to use another lib that does toposort?
        #   igraph: http://igraph.sourceforge.net/
        #   networkx: https://networkx.lanl.gov/
        # Should we try to use cython?
        #   Compiling only that fct is not enough, should we try to add the
        #   deque class too?
        #   And init the deque and other list to an upper bound number of
        #   elements?
        # Maybe Aesara should do online toposort as in
        #   http://code.google.com/p/acyclic
        #
        # The next longest rewriter is the canonizer phase.
        # Then I think it is the [io_?]toposort (need to validate) so check if
        # the solution is also applicable there.

        # We execute `validate` after this number of change.
        prof = {
            "opt": self,
            "node_before": len(fgraph.apply_nodes),
            "nb_call_replace": 0,
            "nb_call_validate": 0,
            "nb_inconsistent": 0,
            "ndim": defaultdict(lambda: 0),
        }

        # -1 means "auto": validate less often on large graphs.
        check_each_change = config.tensor__insert_inplace_optimizer_validate_nb
        if check_each_change == -1:
            if len(fgraph.apply_nodes) > 500:
                check_each_change = 10
            else:
                check_each_change = 1

        nb_change_no_validate = 0
        chk = fgraph.checkpoint()

        if fgraph.update_mapping:
            update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
        else:
            update_outs = []

        # Inputs protected by a Supervisor feature, and graph outputs, must
        # never be destroyed in place.
        protected_inputs = [
            f.protected
            for f in fgraph._features
            if isinstance(f, aesara.compile.function.types.Supervisor)
        ]
        protected_inputs = sum(protected_inputs, [])  # flatten the list
        protected_inputs.extend(fgraph.outputs)
        for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
            op = node.op
            if not isinstance(op, self.op):
                continue
            # If big graph and the outputs are scalar, do not make it
            # inplace.
            if (
                check_each_change != 1
                and
                # If multiple outputs, they must all have the same size,
                # so only check the first.
                getattr(node.outputs[0].type, "ndim", -1) == 0
            ):
                continue

            if op.inplace_pattern:
                # Maybe this isn't needed anymore, but I don't want to
                # rish regression now. This case only happen if the
                # original node add already some inplace patter and we
                # still try to add more pattern.

                baseline = op.inplace_pattern
                candidate_outputs = [
                    i for i in range(len(node.outputs)) if i not in baseline
                ]
                # node inputs that are Constant, already destroyed,
                # or fgraph protected inputs and fgraph outputs can't be used as
                # inplace target.
                # Remove here as faster.
                candidate_inputs = [
                    i
                    for i in range(len(node.inputs))
                    if i not in baseline.values()
                    and not isinstance(node.inputs[i], Constant)
                    and
                    # the next line should not be costly most of the time.
                    not fgraph.has_destroyers([node.inputs[i]])
                    and node.inputs[i] not in protected_inputs
                ]
            else:
                baseline = []
                candidate_outputs = list(range(len(node.outputs)))
                # node inputs that are Constant, already destroyed,
                # fgraph protected inputs and fgraph outputs can't be used as inplace
                # target.
                # Remove here as faster.
                candidate_inputs = [
                    i
                    for i in range(len(node.inputs))
                    if not isinstance(node.inputs[i], Constant)
                    and not fgraph.has_destroyers([node.inputs[i]])
                    and node.inputs[i] not in protected_inputs
                ]

            verbose = False

            raised_warning = not verbose

            for candidate_output in candidate_outputs:
                # If the output of the node can be established as an update
                # output of the fgraph, visit the candidate_inputs in an order
                # that will improve the chances of making the node operate
                # inplace on the input it's meant to update
                candidate_out_var = node.outputs[candidate_output]
                sorted_candidate_inputs = candidate_inputs

                if candidate_out_var in update_outs:
                    # The candidate output is an update. Sort the
                    # variables in candidate_inputs in the following order:
                    # - Vars corresponding to the actual updated input
                    #   (best case scenario is for the node that procudes
                    #   an update to operate inplace on the variable to
                    #   update)
                    # - Vars computed inplace on the updates input (second
                    #   best scenario if for the node to work inplace on
                    #   a variable obtained by a chain of inplace on the
                    #   variable to update. In some cases, this will be
                    #   equivalent to operating inplace on the variable to
                    #   update)
                    # - Remaining variables
                    updated_inputs = []
                    for i, f_out in enumerate(fgraph.outputs):
                        if f_out is candidate_out_var and i in fgraph.update_mapping:
                            updated_inp_idx = fgraph.update_mapping[i]
                            updated_inputs.append(fgraph.inputs[updated_inp_idx])

                    updated_vars = []
                    vars_from_inplace = []
                    other_vars = []
                    for inp_idx in candidate_inputs:
                        inp = node.inputs[inp_idx]
                        if inp in updated_inputs:
                            # the candidate input is the actual updated input
                            updated_vars.append(inp_idx)
                        elif (
                            hasattr(fgraph, "destroy_handler")
                            and inp.owner
                            and any(
                                fgraph.destroy_handler.root_destroyer.get(up_inp, None)
                                is inp.owner
                                for up_inp in updated_inputs
                            )
                        ):
                            # the candidate input is a variable computed
                            # inplace on the updated input via a sequence of
                            # one or more inplace operations
                            vars_from_inplace.append(inp_idx)
                        else:
                            other_vars.append(inp_idx)

                    sorted_candidate_inputs = (
                        updated_vars + vars_from_inplace + other_vars
                    )

                for candidate_input in sorted_candidate_inputs:
                    # remove inputs that don't have the same dtype as the output
                    if (
                        node.inputs[candidate_input].type
                        != node.outputs[candidate_output].type
                    ):
                        continue

                    inplace_pattern = dict(baseline)
                    inplace_pattern[candidate_output] = candidate_input
                    try:
                        # Build the in-place scalar op, then a new Elemwise node
                        # carrying the extended inplace pattern.
                        if hasattr(op.scalar_op, "make_new_inplace"):
                            new_scal = op.scalar_op.make_new_inplace(
                                aes.transfer_type(
                                    *[
                                        inplace_pattern.get(i, o.dtype)
                                        for i, o in enumerate(node.outputs)
                                    ]
                                )
                            )
                        else:
                            new_scal = op.scalar_op.__class__(
                                aes.transfer_type(
                                    *[
                                        inplace_pattern.get(i, None)
                                        for i in range(len(node.outputs))
                                    ]
                                )
                            )
                        new_outputs = self.op(new_scal, inplace_pattern)(
                            *node.inputs, return_list=True
                        )
                        new_node = new_outputs[0].owner

                        for r, new_r in zip(node.outputs, new_outputs):
                            prof["nb_call_replace"] += 1
                            fgraph.replace(
                                r, new_r, reason="inplace_elemwise_optimizer"
                            )
                        nb_change_no_validate += 1
                        prof["ndim"][candidate_out_var.ndim] += 1
                        if nb_change_no_validate >= check_each_change:
                            prof["nb_call_validate"] += 1
                            fgraph.validate()
                            chk = fgraph.checkpoint()
                            nb_change_no_validate = 0
                    except (ValueError, InconsistencyError) as e:
                        # Replacement was rejected: roll back to the last
                        # validated checkpoint and try the next candidate.
                        prof["nb_inconsistent"] += 1
                        if check_each_change != 1 and not raised_warning:
                            print(
                                (
                                    "Some inplace rewriting was not "
                                    "performed due to an unexpected error:"
                                ),
                                file=sys.stderr,
                            )
                            print(e, file=sys.stderr)
                            raised_warning = True
                        fgraph.revert(chk)
                        continue
                    candidate_inputs.remove(candidate_input)
                    node = new_node
                    baseline = inplace_pattern
                    break

        # Validate any changes made since the last checkpoint.
        if nb_change_no_validate > 0:
            try:
                fgraph.validate()
            except Exception:
                if not raised_warning:
                    print(
                        (
                            "Some inplace rewriting was not "
                            "performed due to an unexpected error"
                        ),
                        file=sys.stderr,
                    )
                fgraph.revert(chk)
        return prof

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        # One-line summary used by rewrite-database reporting.
        print(
            f"{' ' * level}{self.__class__.__name__} ({self.op})",
            file=stream,
        )
return inplace_elemwise_optimizer
# Instantiate the rewriter for `Elemwise` ops and register it in the
# compilation database under its historical and current tags.
inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register(
    "inplace_elemwise_opt",
    inplace_elemwise_optimizer,
    "inplace_opt",  # for historic reason
    "inplace_elemwise_optimizer",
    "fast_run",
    "inplace",
    position=75,
)
def register_useless(
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
):
......@@ -585,159 +222,6 @@ def register_specialize_device(
return node_rewriter
def apply_local_dimshuffle_lift(fgraph, var):
    """Apply `local_dimshuffle_lift` to the node that produces `var`.

    Returns the rewritten variable when the rewrite fires, otherwise
    returns `var` unchanged (including when `var` has no owner node).
    """
    owner = var.owner
    if owner is None:
        return var
    replacement = local_dimshuffle_lift.transform(fgraph, owner)
    return replacement[0] if replacement else var
def is_dimshuffle_useless(new_order, input):
    """
    Checks for two types of useless dimshuffles:
      1 - dimshuffle all dimensions in order.
      2 - dimshuffle a broadcastable dimension.
    """
    # A useless dimshuffle must keep the number of dimensions unchanged.
    if len(new_order) != input.type.ndim:
        return False
    # Positions that may be swapped/inserted freely: broadcastable input
    # dimensions and the "x" marker for a new broadcastable axis.
    expandable = {
        axis
        for axis, is_broadcastable in enumerate(input.type.broadcastable)
        if is_broadcastable
    }
    expandable.add("x")
    # Every output position must either take its own input dimension or
    # exchange one broadcastable slot for another.
    return all(
        dst == src or (src in expandable and dst in expandable)
        for src, dst in enumerate(new_order)
    )
@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
    """
    "Lifts" DimShuffle through Elemwise operations and merges
    consecutive DimShuffles. Basically, applies the following
    transformations on the whole graph:

    DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
    DimShuffle(DimShuffle(x)) => DimShuffle(x)
    DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)

    After this transform, clusters of Elemwise operations are
    void of DimShuffle operations.

    """
    op = node.op
    if not isinstance(op, DimShuffle):
        return False

    inp = node.inputs[0]
    inode = inp.owner
    new_order = op.new_order
    # Case 1: push the DimShuffle through a single-client `Elemwise`,
    # applying it (recursively lifted) to each of the Elemwise inputs.
    if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1):
        # Don't use make_node to have tag.test_value set.
        new_inputs = []
        for inp in inode.inputs:
            new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
            new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
        copy_stack_trace(node.outputs[0], new_inputs)
        ret = inode.op(*new_inputs, return_list=True)
        return ret
    # Case 2: compose with a preceding DimShuffle into one order; from here
    # on, `inp`/`new_order` refer to the inner input and the merged order.
    if inode and isinstance(inode.op, DimShuffle):
        new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order]
        inp = inode.inputs[0]
    # If the (possibly merged) order is a no-op, drop the DimShuffle
    # entirely; otherwise replace the pair with one merged DimShuffle.
    if is_dimshuffle_useless(new_order, inp):
        return [inp]
    elif inode and isinstance(inode.op, DimShuffle):
        ret = op.__class__(inp.type.broadcastable, new_order)(inp)
        ret = apply_local_dimshuffle_lift(fgraph, ret)
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
def local_useless_dimshuffle_makevector(fgraph, node):
    r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.

    This rewrite is needed in order to clean up after
    `local_subtensor_remove_broadcastable_index`, which produces a
    not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)`
    (i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`.
    """
    # Only consider a `DimShuffle` that removes the single broadcastable
    # dimension (i.e. produces a scalar).
    if node.op.new_order != ():
        return

    mv_var = node.inputs[0]
    mv_node = mv_var.owner
    if mv_node is None or not isinstance(mv_node.op, MakeVector):
        return
    if mv_var.broadcastable != (True,):
        return

    # A length-one broadcastable `MakeVector` has exactly one input; the
    # scalar result is just that input.
    assert len(mv_node.inputs) == 1
    return [mv_node.inputs[0]]
@register_canonicalize
@node_rewriter([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node):
    """
    Removes useless DimShuffle operation inside Reshape:

      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
      reshape(col.dimshuffle(0), shp) => reshape(col, shp)

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    shuffled = node.inputs[0]
    if shuffled.owner is None or not isinstance(shuffled.owner.op, DimShuffle):
        return False

    inner_var = shuffled.owner.inputs[0]
    # Keep only the source axes that land on non-broadcastable output
    # dimensions; broadcastable ones are absorbed by the reshape anyway.
    kept_axes = [
        axis
        for axis, bcast in zip(shuffled.owner.op.new_order, shuffled.broadcastable)
        if not bcast
    ]
    # The DimShuffle is removable iff those axes keep their relative order.
    order_preserved = all(a <= b for a, b in zip(kept_axes, kept_axes[1:]))
    if order_preserved:
        target_shape = node.inputs[1]
        ret = op.__class__(node.outputs[0].ndim)(inner_var, target_shape)
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
@register_canonicalize
@register_specialize
@node_rewriter([TensorFromScalar])
......@@ -766,722 +250,6 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s]
class MakeVectorPrinter(Printer):
    # Pretty-printer that renders a `MakeVector` output as ``[e1, e2, ...]``.
    def process(self, r, pstate):
        if r.owner is None:
            raise TypeError("Can only print make_vector.")
        elif isinstance(r.owner.op, MakeVector):
            # Print each element at the current precedence, then join them
            # into a bracketed, comma-separated list.
            with set_precedence(pstate):
                s = [pstate.pprinter.process(inp) for inp in r.owner.inputs]
            return f"[{', '.join(s)}]"
        else:
            raise TypeError("Can only print make_vector.")


# Register the custom printer so `pprint` renders `MakeVector` nodes as lists.
pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(Feature):
r"""A `Feature` that tracks shape information in a graph.
This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
`Shape_i` and `MakeVector` `Op`\s.
This `Feature` and its associated rewrites have several goals:
1. to "lift" `Shape`\s to as close to the inputs as possible,
2. to infer the shape of every node in the graph in terms of the
input shapes, and
3. remove fill `Op`\s (e.g. `Second`) from the graph.
Lifting shapes as close to the inputs as possible is important for
canonicalization because it is very bad form to have to compute
something just to know how big it will be. Firstly, it is a waste
of time to compute such outputs. But it is important to get rid
of these outputs as early as possible in the compilation process
because the extra computations make it appear as if many internal
graph nodes have multiple clients. Many rewrites refuse to
work on nodes with multiple clients.
Lifting is done by using an `<Op>.infer_shape` function if one is
present, or else using a conservative default. An Op that
supports shape-lifting should define a infer_shape(self, fgraph, node,
input_shapes) function. The argument input_shapes is a tuple of
tuples... there is an interior tuple for each input to the node.
The tuple has as many elements as dimensions. The element in
position i of tuple j represents the i'th shape component of the
j'th input. The function should return a tuple of tuples. One
output tuple for each node.output. Again, the i'th element of the
j'th output tuple represents the output[j].shape[i] of the
function. If an output is not a TensorType, then None should be
returned instead of a tuple for that output.
For example the infer_shape for a matrix-matrix product would accept
input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).
Inferring the shape of internal nodes in the graph is important
for doing size-driven rewrites. If we know how big various
intermediate results will be, we can estimate the cost of many Ops
accurately, and generate c-code that is specific [e.g. unrolled]
to particular sizes.
In cases where you cannot figure out the shape, raise a ShapeError.
Notes
-----
Right now there is only the ConvOp that could really take
advantage of this shape inference, but it is worth it even
just for the ConvOp. All that's necessary to do shape
inference is 1) to mark shared inputs as having a particular
shape, either via a .tag or some similar hacking; and 2) to
add an optional In() argument to promise that inputs will
have a certain shape (or even to have certain shapes in
certain dimensions).
We can't automatically infer the shape of shared variables as they can
change of shape during the execution by default.
To use this shape information in rewrites, use the
``shape_of`` dictionary.
For example:
.. code-block:: python
try:
shape_of = fgraph.shape_feature.shape_of
except AttributeError:
# This can happen when the mode doesn't include the ShapeFeature.
return
shape_of_output_zero = shape_of[node.output[0]]
The ``shape_of_output_zero`` symbol will contain a tuple, whose
elements are either integers or symbolic integers.
TODO: check to see if the symbols are necessarily
non-constant... or are integer literals sometimes Aesara
constants?? That would be confusing.
"""
def get_node_infer_shape(self, node):
    """Compute the output shapes of `node` via its `Op.infer_shape`.

    Falls back to `default_infer_shape` when the op has no `infer_shape`
    or when it raises `ShapeError`; other failures are either re-raised
    or logged depending on ``config.on_shape_error``.
    """
    try:
        shape_infer = node.op.infer_shape
    except AttributeError:
        shape_infer = self.default_infer_shape

    try:
        o_shapes = shape_infer(
            self.fgraph, node, [self.shape_of[r] for r in node.inputs]
        )
    except ShapeError:
        # `ShapeError` is the supported way for an op to say "cannot infer".
        o_shapes = self.default_infer_shape(
            self.fgraph, node, [self.shape_of[r] for r in node.inputs]
        )
    except NotImplementedError as e:
        raise NotImplementedError(
            "Code called by infer_shape failed raising a "
            "NotImplementedError. Raising NotImplementedError to "
            "indicate that a shape cannot be computed is no longer "
            "supported, and one should now use ShapeError "
            f"instead. The original exception message is: {e}"
        ).with_traceback(e.__traceback__)
    except Exception as e:
        msg = (
            f"Failed to infer_shape from Op {node.op}.\nInput shapes: "
            f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: "
            f"{type(e)}\nException message: {str(e)}\nTraceback: {traceback.format_exc()}"
        )
        if config.on_shape_error == "raise":
            raise Exception(msg).with_traceback(e.__traceback__)
        else:
            # Warn and fall back to the conservative default inference.
            _logger.warning(msg)
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )

    return o_shapes
def get_shape(self, var, idx):
    """Rewrites can call this to get a `Shape_i`.

    It is better to call this then use directly ``shape_of[var][idx]``
    as this method should update `shape_of` if needed.

    TODO: Up to now, we don't update it in all cases. Update in all cases.
    """
    r = self.shape_of[var][idx]
    # If the recorded shape is a `Shape_i` of a variable that is no longer
    # in the graph, re-infer the shape from `var`'s owner node.
    if (
        r.owner
        and isinstance(r.owner.op, Shape_i)
        and r.owner.inputs[0] not in self.fgraph.variables
    ):
        assert var.owner
        node = var.owner
        # recur on inputs
        for i in node.inputs:
            if getattr(i.type, "ndim", None) > 0:
                self.get_shape(i, 0)
        o_shapes = self.get_node_infer_shape(node)
        assert len(o_shapes) == len(node.outputs)

        # Only change the variables and dimensions that would introduce
        # extra computation
        for new_shps, out in zip(o_shapes, node.outputs):
            if not hasattr(out.type, "ndim"):
                continue

            merged_shps = list(self.shape_of[out])
            changed = False
            for i in range(out.type.ndim):
                n_r = merged_shps[i]
                if (
                    n_r.owner
                    and isinstance(n_r.owner.op, Shape_i)
                    and n_r.owner.inputs[0] not in self.fgraph.variables
                ):
                    changed = True
                    merged_shps[i] = new_shps[i]

            if changed:
                self.set_shape(out, merged_shps, override=True)
        # Re-read after the update so the caller gets the refreshed entry.
        r = self.shape_of[var][idx]
    return r
def shape_ir(self, i, r):
    """Return symbolic r.shape[i] for tensor variable r, int i."""
    # Use the statically-known dimension when the type provides one.
    if hasattr(r.type, "shape") and r.type.shape[i] is not None:
        return constant(r.type.shape[i], dtype="int64")
    else:
        # Do not call make_node for test_value
        s = Shape_i(i)(r)
        try:
            # Collapse to a constant when the symbolic shape is one.
            s = get_scalar_constant_value(s)
        except NotScalarConstantError:
            pass
        return s
def shape_tuple(self, r):
    """Return a tuple of symbolic shape vars for tensor variable r."""
    # Variables without an `ndim` (non-tensor types) have no shape tuple.
    if not hasattr(r.type, "ndim"):
        # This happen for NoneConst.
        return None
    return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))
def default_infer_shape(self, fgraph, node, i_shapes):
    """Return a list of shape tuple or None for the outputs of node.

    This function is used for Ops that don't implement infer_shape.
    Ops that do implement infer_shape should use the i_shapes parameter,
    but this default implementation ignores it.
    """
    out_shapes = []
    for out_var in node.outputs:
        try:
            out_shapes.append(self.shape_tuple(out_var))
        except AttributeError:
            # Outputs whose type carries no shape info map to None.
            out_shapes.append(None)
    return out_shapes
def unpack(self, s_i, var):
    """Return a symbolic integer scalar for the shape element s_i.

    The s_i argument was produced by the infer_shape() of an Op subclass.

    var: the variable that correspond to s_i. This is just for
    error reporting.

    """
    assert s_i is not None

    # Share a single cached symbolic `1` for the common broadcastable case.
    if s_i == 1:
        return self.lscalar_one
    # Normalize integral floats (e.g. 2.0) to ints before the int checks.
    if isinstance(s_i, float) and int(s_i) == s_i:
        s_i = int(s_i)
    if isinstance(s_i, (np.integer, int)) or (
        isinstance(s_i, np.ndarray) and s_i.ndim == 0
    ):
        # this shape is a constant
        if s_i < 0:
            msg = "There is a negative shape in the graph!"
            msg += get_variable_trace_string(var)
            # The rest of the pipeline don't handle correctly this
            # case.  So we have 2 choices, stop compilation or
            # consider the shape as unknown.  As we have more
            # chance to give the stack trace here then later, I
            # choose that options as it would give better error
            # message.
            raise AssertionError(msg)
        return constant(s_i, dtype="int64")
    if isinstance(s_i, (tuple, list)):
        # this dimension is the same as many of the inputs
        # which tells us that if one of the inputs is known,
        # the others all become known.
        # TODO: should be implemented in Elemwise, and Dot
        #
        # worst case, we loop over shape_of and replace things
        raise NotImplementedError(s_i)

    # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
    if (
        s_i.owner
        and isinstance(s_i.owner.op, Subtensor)
        and s_i.owner.inputs[0].owner
        and isinstance(s_i.owner.inputs[0].owner.op, Shape)
    ):
        assert s_i.type.ndim == 0
        assert len(s_i.owner.op.idx_list) == 1

        # The current Subtensor always put constant index in the graph.
        # This was not True in the past.  So call the Subtensor function
        # that will return the right index.
        idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
        assert len(idx) == 1
        idx = idx[0]
        try:
            i = get_scalar_constant_value(idx)
        except NotScalarConstantError:
            pass
        else:
            # Executed only if no exception was raised
            x = s_i.owner.inputs[0].owner.inputs[0]
            # x should already have been imported, and should be in shape_of.
            s_i = self.shape_of[x][i]

    if s_i.type.dtype in integer_dtypes:
        if getattr(s_i.type, "ndim", 0):
            raise TypeError("Shape element must be scalar", s_i)
        return s_i
    else:
        raise TypeError(
            "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None)
        )
def set_shape(self, r, s, override=False):
    """Assign the shape `s` to previously un-shaped variable `r`.

    Parameters
    ----------
    r : a variable
    s : None or a tuple of symbolic integers
    override : If False, it mean r is a new object in the fgraph.
        If True, it mean r is already in the fgraph and we want to
        override its shape.

    """
    if not override:
        assert r not in self.shape_of, "r already in shape_of"
    if s is None:
        # No shape information available for `r`.
        self.shape_of[r] = s
    else:
        if not isinstance(s, (tuple, list)):
            raise TypeError("shapes must be tuple/list", (r, s))

        if r.type.ndim != len(s):
            sio = StringIO()
            aesara.printing.debugprint(r, file=sio, print_type=True)
            raise AssertionError(
                f"Something inferred a shape with {len(s)} dimensions "
                f"for a variable with {int(r.type.ndim)} dimensions"
                f" for the variable:\n{sio.getvalue()}"
            )

        # Prefer statically-known dimensions from the type; otherwise
        # normalize each entry of `s` through `unpack`.
        shape_vars = []
        for i in range(r.type.ndim):
            if hasattr(r.type, "shape") and r.type.shape[i] is not None:
                shape_vars.append(constant(r.type.shape[i], dtype="int64"))
            else:
                shape_vars.append(self.unpack(s[i], r))
        # Broadcastable dimensions must have a shape of exactly one.
        assert all(
            not hasattr(r.type, "broadcastable")
            or not r.type.broadcastable[i]
            or self.lscalar_one.equals(shape_vars[i])
            or self.lscalar_one.equals(extract_constant(shape_vars[i]))
            for i in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(shape_vars)
        # Maintain the reverse index from shape vars to the variables
        # whose shapes reference them.
        for sv in shape_vars:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)
    def update_shape(self, r, other_r):
        """Replace the shape of `r` by the shape of `other_r`.

        If, on some dimensions, the shape of `other_r` is not informative,
        keep the shape of `r` on those dimensions.

        """
        # other_r should already have a shape
        assert other_r in self.shape_of, ("other_r not in shape_of", other_r)
        other_shape = self.shape_of[other_r]

        # If other_shape has no information, the call is pointless.
        if other_shape is None:
            return

        if r in self.shape_of:
            # NOTE: `r_shape` may still be `None` here (shape info explicitly
            # unknown); the first branch of the merge loop handles that case.
            r_shape = self.shape_of[r]
        else:
            # If no info is known on r's shape, use other_shape
            self.set_shape(r, other_shape)
            return
        if (
            other_r.owner
            and r.owner
            and other_r.owner.inputs == r.owner.inputs
            and other_r.owner.op == r.owner.op
        ):
            # We are doing a merge, so the two shape graphs will be the
            # same.  This is only done so that we call `ancestors` less
            # frequently.
            return

        # Merge other_shape with r_shape, giving the priority to other_shape
        merged_shape = []
        for i, ps in enumerate(other_shape):
            if r_shape is None and other_shape:
                # `r` carries no shape info at all: take everything from
                # `other_shape`.
                merged_shape.append(other_shape[i])
            elif (
                ps.owner
                and isinstance(getattr(ps.owner, "op", None), Shape_i)
                and ps.owner.op.i == i
                and ps.owner.inputs[0] in (r, other_r)
            ):
                # If other_shape[i] is uninformative, use r_shape[i].
                # For now, we consider 2 cases of uninformative other_shape[i]:
                #  - Shape_i(i)(other_r);
                #  - Shape_i(i)(r).
                merged_shape.append(r_shape[i])
            elif isinstance(r_shape[i], (Constant, int)):
                # We do this to call `ancestors` less often and to make
                # sure we have the simplest shape possible.
                merged_shape.append(r_shape[i])
            elif isinstance(other_shape[i], (Constant, int)):
                # We do this to call `ancestors` less often and to make
                # sure we have the simplest shape possible.
                merged_shape.append(other_shape[i])
            elif other_shape[i] == r_shape[i]:
                # This means the shapes are equivalent.  We do not want to do
                # the (costly) ancestor check in those cases.
                merged_shape.append(r_shape[i])
            elif r_shape[i] in ancestors([other_shape[i]]):
                # Another case where we want to use r_shape[i] is when
                # other_shape[i] actually depends on r_shape[i]. In that case,
                # we do not want to substitute an expression with another that
                # is strictly more complex. Such a substitution could also lead
                # to cycles: if (in the future) r_shape[i] gets replaced by an
                # expression of other_shape[i], other_shape[i] may end up
                # depending on itself.
                merged_shape.append(r_shape[i])
            else:
                merged_shape.append(other_shape[i])
        # NOTE(review): `and` binds tighter than `or` below, so the first
        # clause reads `not hasattr(...) or (not r.bcast[i] and not
        # other_r.bcast[i])` -- confirm this grouping is intended.
        assert all(
            (
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                and not other_r.type.broadcastable[i]
            )
            or self.lscalar_one.equals(merged_shape[i])
            or self.lscalar_one.equals(
                extract_constant(merged_shape[i], only_process_constants=True)
            )
            for i in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(merged_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)
def set_shape_i(self, r, i, s_i):
"""Replace element i of shape_of[r] by s_i"""
assert r in self.shape_of
prev_shape = self.shape_of[r]
# prev_shape is a tuple, so we cannot change it inplace,
# so we build another one.
new_shape = []
for j, s_j in enumerate(prev_shape):
if j == i:
new_shape.append(self.unpack(s_i, r))
else:
new_shape.append(s_j)
assert all(
not hasattr(r.type, "broadcastable")
or not r.type.broadcastable[idx]
or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx]))
for idx in range(r.type.ndim)
)
self.shape_of[r] = tuple(new_shape)
for sv in self.shape_of[r]:
self.shape_of_reverse_index.setdefault(sv, set()).add(r)
def init_r(self, r):
"""Register r's shape in the shape_of dictionary."""
if r not in self.shape_of:
self.set_shape(r, self.shape_tuple(r))
def make_vector_shape(self, r):
return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")
def on_attach(self, fgraph):
if hasattr(fgraph, "shape_feature"):
raise AlreadyThere("This FunctionGraph already has a ShapeFeature")
if hasattr(self, "fgraph") and self.fgraph != fgraph:
raise Exception("This ShapeFeature is already attached to a graph")
self.fgraph = fgraph
fgraph.shape_feature = self
# Must be local to the object as otherwise we reuse the same
# variable for multiple fgraph!
self.lscalar_one = constant(1, dtype="int64")
assert self.lscalar_one.type.dtype == "int64"
self.fgraph = fgraph
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.shape_of = {}
# Variable ->
self.scheduled = {}
# shape var -> graph v
self.shape_of_reverse_index = {}
for node in fgraph.toposort():
self.on_import(fgraph, node, reason="on_attach")
def on_detach(self, fgraph):
self.shape_of = {}
self.scheduled = {}
self.shape_of_reverse_index = {}
self.fgraph = None
del fgraph.shape_feature
    def on_import(self, fgraph, node, reason):
        """Infer and record the shapes of `node`'s outputs on import."""
        if node.outputs[0] in self.shape_of:
            # this is a revert, not really an import
            for r in node.outputs + node.inputs:
                assert r in self.shape_of
            return

        for i, r in enumerate(node.inputs):
            # make sure we have shapes for the inputs
            self.init_r(r)

        o_shapes = self.get_node_infer_shape(node)

        # this is packed information
        # an element of o_shapes is either None or a tuple
        # elements of the tuple can be either strings, or ints
        if len(o_shapes) != len(node.outputs):
            raise Exception(
                (
                    f'The infer_shape method for the Op "{node.op}" returned a list '
                    f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} "
                    f" != len(node.outputs) = {len(node.outputs)}"
                )
            )

        # Ensure shapes are in 'int64'. This is to make sure the assert
        # found in the `local_useless_subtensor` rewrite does not fail.
        for sh_idx, sh in enumerate(o_shapes):
            if sh is None:
                continue
            if not isinstance(sh, (list, tuple)):
                raise ValueError(
                    f"infer_shape of {node} didn't return a list of"
                    f" list. It returned '{o_shapes}'"
                )
            # `new_shape` is built lazily: it stays empty until a mis-typed
            # element is found, at which point the elements processed so far
            # are copied over and the offending one is replaced by a cast.
            new_shape = []
            for i, d in enumerate(sh):
                # Note: we ignore any shape element that is not typed (i.e.,
                # does not have a 'dtype' attribute). This means there may
                # still remain int elements that are int32 on 32-bit platforms,
                # but this works with `local_useless_subtensor`, so for now we
                # keep it this way. See #266 for a better long-term fix.
                if getattr(d, "dtype", "int64") != "int64":
                    assert d.dtype in discrete_dtypes, (node, d.dtype)
                    assert str(d.dtype) != "uint64", node
                    new_shape += sh[len(new_shape) : i + 1]
                    if isinstance(d, Constant):
                        casted_d = constant(d.data, dtype="int64")
                    else:
                        casted_d = cast(d, "int64")
                    new_shape[i] = casted_d
            if new_shape:
                # We replace the shape with wrong dtype by the one with
                # 'int64'.  Copy over the (well-typed) tail first.
                new_shape += sh[len(new_shape) :]
                o_shapes[sh_idx] = tuple(new_shape)

        for r, s in zip(node.outputs, o_shapes):
            self.set_shape(r, s)
    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        """Keep the shape bookkeeping consistent when `r` is replaced by `new_r`."""
        if new_r not in self.shape_of:
            # It can happen that the fgraph didn't call on_import for some
            # new_r.  This happens when new_r doesn't have an owner (i.e. it
            # is a constant or an input of the graph).  update_shape assumes
            # that both r and new_r are in shape_of.
            self.init_r(new_r)

        # This tells us that r and new_r must have the same shape if
        # we didn't know that the shapes are related, now we do.
        self.update_shape(new_r, r)

        # change_input happens in two cases:
        # 1) we are trying to get rid of r, or
        # 2) we are putting things back after a failed transaction.

        # In case 1, if r has a shape_i client, we will want to
        # replace the shape_i of r with the shape of new_r.  Say that
        # r is *scheduled*.
        # At that point, node is no longer a client of r, but of new_r
        for (shpnode, idx) in fgraph.clients[r] + [(node, i)]:
            if isinstance(getattr(shpnode, "op", None), Shape_i):
                idx = shpnode.op.i
                repl = self.shape_of[new_r][idx]
                if repl.owner is shpnode:
                    # This means the replacement shape object is
                    # exactly the same as the current shape object.  So
                    # no need for replacement.
                    continue
                if (
                    repl.owner
                    and repl.owner.inputs[0] is shpnode.inputs[0]
                    and isinstance(repl.owner.op, Shape_i)
                    and repl.owner.op.i == shpnode.op.i
                ):
                    # The replacement is a shape_i of the same
                    # input.  So no need to do this equivalent
                    # replacement.
                    continue

                if shpnode.outputs[0] in ancestors([repl]):
                    raise InconsistencyError(
                        "This substitution would insert a cycle in the graph:"
                        f"node: {node}, i: {i}, r: {r}, new_r: {new_r}"
                    )

                # Defer the actual replacement: `local_track_shape_i` will
                # consume `self.scheduled` later.
                self.scheduled[shpnode] = new_r
        # In case 2, if r is a variable that we've scheduled for shape update,
        # then we should cancel it.
        unscheduled = [k for k, v in self.scheduled.items() if v == r]
        for k in unscheduled:
            del self.scheduled[k]

        # In either case, r could be in shape_of.values(), that is, r itself
        # is the shape of something.  In that case, we want to update
        # the value in shape_of, to keep it up-to-date.
        for v in self.shape_of_reverse_index.get(r, []):
            # The reverse index is only approximate.  It is not updated on
            # deletion of variables, or on change_input so it might be the
            # case that there are a few extra `v`'s in it that no longer have
            # a shape of r or possibly have been deleted from shape_of
            # entirely.  The important thing is that it permits to recall
            # all variables with r in their shape.
            for ii, svi in enumerate(self.shape_of.get(v, [])):
                if svi == r:
                    self.set_shape_i(v, ii, new_r)
        self.shape_of_reverse_index[r] = set()
    def same_shape(
        self,
        x: Variable,
        y: Variable,
        dim_x: Optional[int] = None,
        dim_y: Optional[int] = None,
    ) -> bool:
        """Return ``True`` if `x` and `y` have the same shape.

        Parameters
        ==========
        x
            The `Variable` for which its shape is to be compared with `y`'s shape.
        y
            The `Variable` for which its shape is to be compared with `x`'s shape.
        dim_x
            If non ``None``, compare only the dimension of `x` equal to
            `dim_x`.
        dim_y
            If non ``None``, compare only the dimension of `y` equal to
            `dim_y`.

        """
        sx = self.shape_of[x]
        sy = self.shape_of[y]

        # Unknown shape info on either side means we cannot prove equality.
        if sx is None or sy is None:
            return False

        if dim_x is not None:
            sx = [sx[dim_x]]

        if dim_y is not None:
            sy = [sy[dim_y]]

        if len(sx) != len(sy):
            return False

        # Canonicalize the graphs so that comparisons are reasonable
        # TODO FIXME: This should *not* need to be performed manually here.
        # Instead, the shape information in `self.shape_of` should be operated
        # upon alongside all the other elements in a `FunctionGraph` (e.g. as
        # if `self.shape_of.values()` were additional outputs).
        shapes_fg = FunctionGraph(
            outputs=sx + sy,
            # features=[self],
            clone=True,
            # copy_inputs=False,
        )
        # Imported here to avoid an import cycle at module load time.
        from aesara.graph.rewriting.utils import rewrite_graph

        canon_shapes = rewrite_graph(
            shapes_fg, custom_rewrite=topo_constant_folding
        ).outputs

        sx = canon_shapes[: len(sx)]
        sy = canon_shapes[len(sx) :]

        for dx, dy in zip(sx, sy):
            if not equal_computations([dx], [dy]):
                return False

        return True
def clone(self):
return type(self)()
class ShapeOptimizer(GraphRewriter):
    """Rewriter whose only job is to attach a `ShapeFeature` to the graph."""

    def add_requirements(self, fgraph):
        # Make shape tracking available to subsequent rewrites.
        feature = ShapeFeature()
        fgraph.attach_feature(feature)

    def apply(self, fgraph):
        # All of the work happens through the attached feature's callbacks.
        pass
class UnShapeOptimizer(GraphRewriter):
    """Rewriter that removes every `ShapeFeature` from the graph."""

    def apply(self, fgraph):
        # Snapshot the feature list first: `remove_feature` mutates
        # `fgraph._features`, and removing entries while iterating the live
        # list can skip elements.
        for feature in list(fgraph._features):
            if isinstance(feature, ShapeFeature):
                fgraph.remove_feature(feature)
# Register `ShapeOpt` right after the merge optimization at position 0: we do
# not want to track the shapes of nodes that are about to be merged away.
aesara.compile.mode.optdb.register(
    "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
)
# `UnShapeOpt` is not enabled by default for now.  Some cross-entropy rewrites
# use the shape feature; they run at step 2.01, `uncanonicalize` at step 3,
# and the GPU transfer passes around 48.5 -- so position 10 is a safe spot
# after the shape-dependent rewrites.
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
@register_specialize("local_alloc_elemwise")
@node_rewriter([Elemwise])
def local_elemwise_alloc(fgraph, node):
......@@ -1815,43 +583,6 @@ compile.optdb.register(
)
@register_specialize
@register_canonicalize
@node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node):
    """Replace ``Shape(x)`` with the shape vector tracked by `ShapeFeature`."""
    if not isinstance(node.op, Shape):
        return
    if not hasattr(fgraph, "shape_feature"):
        return
    ret = fgraph.shape_feature.make_vector_shape(node.inputs[0])
    # Preserve the stack trace of the output being replaced.
    copy_stack_trace(node.outputs[0], ret)
    return [ret]
@register_specialize
@register_canonicalize
@node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node):
    """Replace a scheduled `Shape_i` node with the shape of its replacement."""
    if not isinstance(node.op, Shape_i):
        return False

    try:
        scheduled = fgraph.shape_feature.scheduled
    except AttributeError:
        # No `ShapeFeature` attached to this graph.
        return False

    if node not in scheduled:
        return False

    # Don't unschedule `node`, as it could be reinserted in the fgraph;
    # the `ShapeFeature`'s internal structures are left untouched.
    replacement = scheduled[node]
    return [fgraph.shape_feature.shape_of[replacement][node.op.i]]
@register_useless
@register_canonicalize("fast_compile")
@register_specialize
......@@ -2130,153 +861,6 @@ compile.optdb["useless"].register(
)
@register_canonicalize
@node_rewriter([Elemwise])
def local_upcast_elemwise_constant_inputs(fgraph, node):
    """This explicitly upcasts constant inputs to elemwise Ops, when
    those Ops do implicit upcasting anyway.

    Rationale: it helps merge things like (1-x) and (1.0 - x).

    """
    if len(node.outputs) > 1:
        return
    try:
        # NOTE(review): `ShapeFeature` does not obviously define a `shape_i`
        # attribute; if it is absent this falls back to `None` and the
        # non-broadcastable-constant branch below bails out -- confirm.
        shape_i = fgraph.shape_feature.shape_i
    except AttributeError:
        shape_i = None
    if isinstance(node.op, Elemwise):
        scalar_op = node.op.scalar_op
        if getattr(scalar_op, "output_types_preference", None) in (
            aes.upgrade_to_float,
            aes.upcast_out,
        ):
            # this is the kind of op that we can screw with the input
            # dtypes by upcasting explicitly
            output_dtype = node.outputs[0].type.dtype
            new_inputs = []
            for i in node.inputs:
                if i.type.dtype == output_dtype:
                    new_inputs.append(i)
                else:
                    try:
                        # works only for scalars
                        cval_i = get_scalar_constant_value(
                            i, only_process_constants=True
                        )
                        if all(i.broadcastable):
                            # Fully broadcastable: pad the cast scalar back
                            # up to the input's number of dimensions.
                            new_inputs.append(
                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
                            )
                        else:
                            if shape_i is None:
                                return
                            # Materialize the constant at the input's runtime
                            # shape via `Alloc`.
                            new_inputs.append(
                                alloc(
                                    cast(cval_i, output_dtype),
                                    *[shape_i(d)(i) for d in range(i.ndim)],
                                )
                            )
                    except NotScalarConstantError:
                        # for the case of a non-scalar
                        if isinstance(i, TensorConstant):
                            new_inputs.append(cast(i, output_dtype))
                        else:
                            new_inputs.append(i)
            if new_inputs != node.inputs:
                rval = [node.op(*new_inputs)]
                if not node.outputs[0].type.is_super(rval[0].type):
                    # This can happen for example when floatX=float32
                    # and we do the true division between and int64
                    # and a constant that will get typed as int8.
                    # As this is just to allow merging more case, if
                    # the upcast don't work, we can just skip it.
                    return

                # Copy over output stacktrace from before upcasting
                copy_stack_trace(node.outputs[0], rval)
                return rval
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
    """Remove an `Unbroadcast` that does not actually change the broadcasting pattern.

    TODO: Implement an equivalent rewrite for `SpecifyShape`.
    """
    if not isinstance(node.op, Unbroadcast):
        return None

    x = node.inputs[0]
    if x.broadcastable == node.outputs[0].broadcastable:
        # Nothing was modified.  No need to copy over a stack trace, because
        # `x` should already have one.
        return [x]

    # Keep only the axes that actually unbroadcast something.
    kept_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
    if kept_axes == node.op.axes:
        # Every axis is useful; leave the node alone.
        return None

    replacement = unbroadcast(x, *kept_axes)
    # Copy over the stack trace from the previous output.
    copy_stack_trace(node.outputs, replacement)
    return [replacement]
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
    """Lift `Unbroadcast` through unary `Elemwise` ops and merge consecutive ones.

    ``Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))``
    ``Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)``

    TODO: Implement an equivalent `Elemwise` lift for `SpecifyShape`.
    """
    op = node.op
    if not isinstance(op, Unbroadcast):
        return False

    inner = node.inputs[0]
    inner_apply = inner.owner

    if (
        inner_apply
        and isinstance(inner_apply.op, Elemwise)
        and len(inner_apply.inputs) == 1
        and len(fgraph.clients.get(inner, ())) == 1
    ):
        lifted = unbroadcast(inner_apply.inputs[0], *op.axes)
        copy_stack_trace(node.outputs, lifted)

        rval = inner_apply.op.make_node(lifted).outputs

        # An error in the new graph could have been caused by either the
        # `Unbroadcast` or the `Elemwise`, so copy the stack traces of both
        # the previous output and input onto the new output.
        copy_stack_trace(node.outputs + node.inputs, rval)
        return rval

    if inner_apply and isinstance(inner_apply.op, Unbroadcast):
        # Merge the axis sets of the two consecutive `Unbroadcast`s.
        merged_axes = tuple(set(inner_apply.op.axes).union(set(op.axes)))
        rval = [unbroadcast(inner_apply.inputs[0], *merged_axes)]

        # An error in the new graph could have been caused by either of the
        # two `Unbroadcast` ops, so copy both stack traces.
        copy_stack_trace(node.outputs + node.inputs, rval)
        return rval
@register_specialize
@register_canonicalize
@register_useless
......@@ -2412,7 +996,7 @@ def local_useless_switch(fgraph, node):
if not isinstance(node.op.scalar_op, aes.Switch):
return False
shape_feature: Optional[ShapeFeature] = getattr(fgraph, "shape_feature", None)
shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None)
if shape_feature is None:
return False
......@@ -2537,225 +1121,6 @@ def local_useless_split(fgraph, node):
return [out2]
def local_reshape_chain(op):
    """Build a rewriter that collapses ``op(op(x, shape1), shape2) -> op(x, shape2)``."""

    @node_rewriter([op])
    def f(fgraph, node):
        """
        Reshape(Reshape(x, shape1), shape2) -> Reshape(x, shape2)
        """
        if not check_chain(node, op, op):
            return False

        # TODO: this can permit a failing program to run by eliminating
        # the lower reshape
        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

        # Copy over stacktrace from previous output node, as any error
        # in new computational graph would have been caused by last op
        # in the old computational graph.
        copy_stack_trace(node.outputs, rval)

        # It might happen that the desired output of this node has a
        # broadcastable pattern that does not match that of 'rval'. This is
        # when originally, we were able to figure out that one of the
        # dimensions of the reshape is one, but some other transformation
        # replaced the shape by one for which this cannot be guessed.
        # We should try to figure out why we lost the information about this
        # constant value... but in the meantime, better not apply this
        # rewrite.
        if rval.broadcastable == node.outputs[0].broadcastable:
            return [rval]
        else:
            return False

    return f
# Register the `Reshape`-chain-collapsing rewrite under canonicalization.
register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
@register_useless
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
    """
    Remove two kinds of useless reshape.

    Remove Reshape when both the input and output have a single dimension.
    Remove Reshape when reshaping to the shape of the input.

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    if inp.ndim != output.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # This could hide errors if the user provides inconsistent shapes.
    if inp.ndim == 1 and output.ndim == 1 and inp.broadcastable == output.broadcastable:
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

        # Number of `-1` entries seen; at most one is unambiguous.
        nb_m1 = 0
        shape_match = [False] * inp.ndim
        for dim in range(inp.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.broadcastable[dim] is True
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.broadcastable[dim] and cst_outshp_i == 1:
                shape_match[dim] = True
                continue

            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        if all(shape_match) and nb_m1 <= 1:
            return [inp]

        # TODO later: if all the shapes except one match, we may want to
        # consider it useless as well, like we do in the 1-dim case.
        return False
@register_canonicalize
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
    """
    Broadcastable dimensions in Reshape are replaced with dimshuffle.

    The goal is to avoid using reshape to add or remove broadcastable
    dimensions, but use dimshuffle instead, so dimshuffles can cancel out
    or be removed later on.

    For example:
        - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
        - reshape(x, (1, m, 1, n, 1, 1))
          --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    dimshuffle_new_order = []
    new_output_shape = []
    index = 0  # index over the output of the new reshape
    for i in range(output.ndim):
        # Since output_shape is a symbolic vector, we trust extract_constant
        # to go through however it is formed to see if its i-th element is 1.
        # We need only_process_constants=False for that.
        dim = extract_constant(
            output_shape[i], only_process_constants=False, elemwise=False
        )
        if dim == 1:
            # A constant-1 output dimension becomes a broadcastable axis
            # ("x") added by the `DimShuffle` instead of the `Reshape`.
            dimshuffle_new_order.append("x")
        else:
            dimshuffle_new_order.append(index)
            new_output_shape.append(dim)
            index = index + 1
    # Only rewrite when at least one dimension was dropped from the inner
    # reshape (otherwise the node is left unchanged).
    if index != output.ndim:
        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
        copy_stack_trace(output, inner)
        new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
        copy_stack_trace(output, new_node)
        return new_node
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
    Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))

    Notes
    -----
    This rewrite is needed by `log1msigm_to_softplus` in order to get applied
    when there is a reshape.

    """
    if not isinstance(node.op, Reshape):
        return
    inner_apply = node.inputs[0].owner
    if (
        inner_apply is None
        or not isinstance(inner_apply.op, Elemwise)
        or len(inner_apply.inputs) != 1
    ):
        return

    reshaped = node.op(inner_apply.inputs[0], node.inputs[1])
    # Any error in the new `Reshape` could only have been caused by the old
    # one, so copy its stack trace over.
    copy_stack_trace(node.outputs, reshaped)

    lifted = inner_apply.op(reshaped)
    # An error in the new graph could have been caused by either the old
    # `Reshape` or the old `Elemwise`, so copy both stack traces.
    copy_stack_trace(node.outputs + node.inputs, lifted)
    return [lifted]
# `tensor_copy` is an identity operation, so such nodes can simply be removed.
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@node_rewriter(None)
def constant_folding(fgraph, node):
......@@ -2817,431 +1182,6 @@ register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True)
# Also run topological constant folding as a final rewrite during specialization.
register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
    r"""Create a recursive function that fuses `Elemwise` `Op`\s.

    The basic idea is that we loop through an `Elemwise` node's inputs, find
    other `Elemwise` nodes, determine the scalars input types for all of the
    `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
    and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
    new "fused" `Elemwise`.

    It's parameterized in order to work for `Elemwise` `Op`\s.

    Parameters
    ----------
    op_class : type
        `Elemwise` class (the one that we want to fuse)
    max_input_fct : callable
        A function that returns the maximum number of inputs that this `Elemwise`
        can take.
        On the CPU we limit to 32 input variables since that is the maximum
        NumPy support.
    maker: callable
        A function with the signature ``(node, *args)`` that constructs an
        `op_class` instance (e.g. ``op_class(*args)``).

    Returns
    -------
    callable
        A node-rewriter-style function ``local_fuse(fgraph, node)``.

    """
    if maker is None:

        def maker(node, scalar_op):
            # Default: construct the fused `Elemwise` directly from the
            # composite scalar op.
            return op_class(scalar_op)

    def local_fuse(fgraph, node):
        r"""Fuse `Elemwise` `Op`\s in a node.

        As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
        same shape.

        For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
        compiler do the cast.

        The number of dimensions is validated at call time by Aesara itself.

        """
        # TODO: use broadcast flag?

        # TODO: don't do this rewrite as a `NodeRewriter`.
        # Analyze the graph in terms of elemwise subgraphs, and then
        # replace each subgraph with a Composite version.

        # TODO: use malloc and copy to transfer arguments that don't
        # fit within the parameter space of 256 bytes
        #
        # TODO: Merge with multiple output to merge when an inputs
        # have multiple clients. This can't be done with a `NodeRewriter`

        # TODO: Related: Support composites with multiple outputs

        # TODO: Use Composite to combine Elemwise and Reduce
        # operations.  We have to loop over the data anyway... might
        # as well sum it up while we're at it (this can be trickier
        # than i'm making it seound here. The data-traversal should be
        # done contiguously, and the summing-up might not be easy or
        # worthwhile if the summation axis doesn't line up with a
        # contiguous dimension)

        if type(node.op) is not op_class:
            return False

        if len(node.outputs) > 1:
            # We don't support fusion for nodes with multiple outputs.
            return

        inputs = []  # inputs of the new Elemwise op.
        s_inputs = []  # inputs of the new scalar op used by the Composite.
        # Inputs of the new scalar op that represents the current node.
        s_g = []

        # There is a hard limit of 256 bytes for the formal argument list to a
        # GPU kernel function.
        max_nb_input = max_input_fct(node)
        # The number of inputs to the new fused op if we do not fuse more
        # inputs.
        new_nb_input = len(node.inputs)
        # Did we fuse something?
        # Needed as we can fuse unary op that don't change the number of
        # inputs.
        # And there is a case where the inputs are the same as the current
        # node. That won't change the number of inputs of the new op.
        fused = False

        for i in node.inputs:
            scalar_node: Optional[Apply] = None
            # Will store inputs of the fused node that are not currently inputs
            # of the node we want to create (to avoid duplicating inputs).
            tmp_input = []
            # Same as tmp_input, but for scalars.
            tmp_scalar = []

            # We should not check the number of inputs here
            # As fusing op don't always change the number of input.
            # If a variable is used as multiple into to the same node,
            # we still want to fusion. So we take the set.
            if (
                i.owner
                and isinstance(i.owner.op, op_class)
                and len({n for n, idx in fgraph.clients[i]}) == 1
                and
                # Do not merge elemwise that don't have the same
                # broadcastable pattern to don't redo duplicate
                # computation due to broadcast.
                i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable
            ):
                try:
                    tmp_s_input = []
                    # we should not put duplicate input into s_inputs and inputs
                    for ii in i.owner.inputs:
                        if ii in inputs:
                            # Reuse the scalar already created for this input.
                            tmp_s_input.append(s_inputs[inputs.index(ii)])
                        elif ii in tmp_input:
                            tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
                        else:
                            tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
                            try:
                                tv = get_test_value(ii)
                                # Sometimes the original inputs have
                                # zero-valued shapes in some dimensions, which
                                # implies that this whole scalar thing doesn't
                                # make sense (i.e. we're asking for the scalar
                                # value of an entry in a zero-dimensional
                                # array).
                                # This will eventually lead to an error in the
                                # `compute_test_value` call below when/if
                                # `config.compute_test_value_opt` is enabled
                                # (for debugging, more or less)
                                tmp.tag.test_value = tv.item()
                            except (TestValueError, ValueError):
                                pass

                            tmp_s_input.append(tmp)
                            tmp_input.append(ii)
                            tmp_scalar.append(tmp_s_input[-1])

                    # Use the `Op.make_node` interface in case `Op.__call__`
                    # has been customized
                    scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)

                    if config.compute_test_value_opt != "off":
                        # This is required because `Op.make_node` won't do it
                        compute_test_value(scalar_node)

                    # If the scalar_op doesn't have a C implementation, we skip
                    # its fusion to allow fusion of the other ops
                    i.owner.op.scalar_op.c_code(
                        scalar_node,
                        "test_presence_of_c_code",
                        ["x" for x in i.owner.inputs],
                        ["z" for z in i.owner.outputs],
                        {"fail": "%(fail)s"},
                    )
                except (NotImplementedError, MethodNotDefined):
                    _logger.warning(
                        (
                            "Rewrite warning: "
                            f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
                            " As well as being potentially slow, this also disables "
                            "loop fusion."
                        )
                    )
                    scalar_node = None

            # Compute the number of inputs in case we fuse this input.
            # We subtract 1 because we replace the existing input with the new
            # inputs from `tmp_input`.
            new_nb_input_ = new_nb_input + len(tmp_input) - 1

            # If the new input is already an input of the current node, it was
            # already counted when `new_nb_input` was initialized to
            # len(node.inputs).
            # This can happen when a variable is used both by the Elemwise to
            # fuse and the current node.
            for x in tmp_input:
                if x in node.inputs:
                    new_nb_input_ -= 1

            if scalar_node and (new_nb_input_ <= max_nb_input):
                # Accept the fusion of this input's sub-graph.
                fused = True
                new_nb_input = new_nb_input_
                inputs.extend(tmp_input)
                s_inputs.extend(tmp_scalar)
                s_g.extend(scalar_node.outputs)
            else:
                # We must support the case where the same variable appears many
                # times within the inputs
                if inputs.count(i) == node.inputs.count(i):
                    s = s_inputs[inputs.index(i)]
                else:
                    s = aes.get_scalar_type(i.type.dtype).make_variable()
                    if config.compute_test_value_opt != "off":
                        try:
                            v = get_test_value(i)
                            # See the zero-dimensional test value situation
                            # described above.
                            s.tag.test_value = v.item()
                        except (TestValueError, ValueError):
                            pass

                inputs.append(i)
                s_inputs.append(s)
                s_g.append(s)

        if not fused:
            return False

        if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
            # TODO FIXME: This shouldn't be a generic `Exception`
            raise Exception(
                "Something has gone wrong with the elemwise fusion rewrite; skipping."
            )

        s_new_out = node.op.scalar_op(*s_g, return_list=True)
        try:
            s_new_out[0].owner.op.c_code(
                s_new_out[0].owner,
                "test_presence_of_c_code",
                ["x" for x in s_g],
                ["z" for x in s_new_out],
                {"fail": "%(fail)s"},
            )
        except (NotImplementedError, MethodNotDefined):
            name = str(s_new_out[0].owner.op)
            _logger.warning(
                (
                    "Rewrite warning: "
                    f"The Op {name} does not provide a C implementation."
                    " As well as being potentially slow, this also disables "
                    "loop fusion."
                )
            )
            return False

        # create the composite op.
        composite_op = aes.Composite(s_inputs, s_new_out)

        # create the new node.
        # Do not call make_node to have test_value
        new_node = maker(node, composite_op)(*inputs).owner

        assert len(new_node.outputs) == 1
        assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype

        if len(new_node.inputs) > max_nb_input:
            _logger.warning(
                "Loop fusion failed because the resulting node "
                "would exceed the kernel argument limit."
            )
            return False

        # we fuse as many that we can at the same time to make debug mode faster
        # debug mode will be faster as it won't test all intermediate step.
        while True:
            ret = local_fuse(fgraph, new_node)
            if ret is not False and ret is not None:
                assert len(ret) == len(new_node.outputs)
                assert len(ret) == 1
                new_node = ret[0].owner
            else:
                break

        return new_node.outputs

    return local_fuse
def elemwise_max_input_fct(node):
    """Return the maximum number of inputs an `Elemwise` node may take."""
    # Without a C compiler, `Elemwise.perform` falls back to NumPy ufuncs,
    # which are limited to 31 inputs; the C implementation allows many more.
    return 1024 if config.cxx else 31
# The default CPU `Elemwise` fusion rewrite.
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(GraphRewriter):
    """Graph rewriter that simply runs node fusion operations.

    It repeatedly toposorts the graph and applies ``node_rewriter`` to every
    node until a full pass makes no replacement.

    TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.
    """

    def __init__(self, node_rewriter):
        super().__init__()
        # The node rewriter (e.g. `local_elemwise_fusion`) applied to each node.
        self.node_rewriter = node_rewriter

    def add_requirements(self, fgraph):
        fgraph.attach_feature(ReplaceValidate())

    def apply(self, fgraph):
        """Run the fusion rewriter to a fixed point and return profile data.

        Returns a tuple
        ``(self, nb_iter, nb_replacement, nb_inconsistency_replace,
        validate_time, callback_time, callbacks_time, time_toposort)``
        consumed by `print_profile`.
        """
        did_something = True
        nb_iter = 0
        nb_replacement = 0
        nb_inconsistency_replace = 0
        time_toposort = 0
        if fgraph.profile:
            # Snapshot counters so we can report only the deltas from this run.
            validate_before = fgraph.profile.validate_time
            callbacks_before = fgraph.execute_callbacks_times.copy()
            callback_before = fgraph.execute_callbacks_time
        while did_something:
            t0 = time.time()
            nodelist = list(fgraph.toposort())
            time_toposort += time.time() - t0
            nodelist.reverse()
            did_something = False
            for node in nodelist:
                # Don't try to fuse node that have already been fused.
                if node in fgraph.apply_nodes:
                    new_outputs = self.node_rewriter(fgraph, node)
                    if new_outputs:
                        assert len(new_outputs) == len(node.outputs)
                        try:
                            fgraph.replace_all_validate(
                                list(zip(node.outputs, new_outputs)),
                                reason=self.__class__.__name__,
                            )
                            did_something = True
                            nb_replacement += 1
                        except InconsistencyError:
                            # Replacement rejected by a feature; count and move on.
                            nb_inconsistency_replace += 1
            nb_iter += 1
        if fgraph.profile:
            validate_time = fgraph.profile.validate_time - validate_before
            callback_time = fgraph.execute_callbacks_time - callback_before
            callbacks_time = {}
            for k, v in fgraph.execute_callbacks_times.items():
                if k in callbacks_before:
                    callbacks_time[k] = v - callbacks_before[k]
                else:
                    callbacks_time[k] = v
        else:
            validate_time = None
            callback_time = None
            callbacks_time = {}
        return (
            self,
            nb_iter,
            nb_replacement,
            nb_inconsistency_replace,
            validate_time,
            callback_time,
            callbacks_time,
            time_toposort,
        )

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        """Pretty-print the tuple returned by `apply` to ``stream``."""
        blanc = " " * level
        print(blanc, cls.__name__, file=stream)
        print(blanc, " nb_iter", prof[1], file=stream)
        print(blanc, " nb_replacement", prof[2], file=stream)
        print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
        print(blanc, " validate_time", prof[4], file=stream)
        print(blanc, " callback_time", prof[5], file=stream)
        # Only break the callbacks down when they took a noticeable amount of
        # time (> 1s).
        if prof[5] is not None and prof[5] > 1:
            print(blanc, " callbacks_time", file=stream)
            for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]:
                if i[1] > 0:
                    # BUG FIX: previously these detail lines went to stdout
                    # instead of the caller-supplied `stream`, unlike every
                    # other print in this method.
                    print(blanc, " ", i, file=stream)
        print(blanc, " time_toposort", prof[7], file=stream)
# Register the Elemwise fusion rewrite.  When enabled via
# `config.tensor__local_elemwise_fusion`, it is part of "fast_run"; otherwise
# it is registered without that tag so it can still be requested explicitly.
if config.tensor__local_elemwise_fusion:
    _logger.debug("Enabling Elemwise fusion rewriters in fast_run")
    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
    fuse_seqopt = SequenceDB()
    fuse_seqopt.register(
        "composite_elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fast_run",
        "fusion",
        position=1,
    )
    compile.optdb.register(
        "elemwise_fusion",
        fuse_seqopt,
        "fast_run",
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
else:
    _logger.debug("Not enabling Elemwise fusion rewriters in fast_run")
    compile.optdb.register(
        "elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
@register_canonicalize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
    """Drop unused outputs from a multi-output elemwise `Composite` node."""
    op = node.op
    if not isinstance(op, Elemwise):
        return
    if not isinstance(op.scalar_op, aes.Composite):
        return
    composite = op.scalar_op
    # Positions of outputs that are actually consumed somewhere in the graph.
    used = [pos for pos, out in enumerate(node.outputs) if fgraph.clients[out]]
    if len(used) == len(node.outputs):
        # Every output is in use; nothing to trim.
        return
    trimmed = aes.Composite(
        inputs=composite.inputs,
        outputs=[composite.outputs[pos] for pos in used],
    )
    replacements = Elemwise(scalar_op=trimmed)(*node.inputs, return_list=True)
    return dict(zip((node.outputs[pos] for pos in used), replacements))
# NOTE(review): the lines below are diff-extraction residue — a set of
# decorators fused with a diff hunk header and the final line of
# `local_useless_topk`.  They are not valid Python; restore the full
# definition from the original revision before editing.
@register_canonicalize("fast_compile")
@register_useless("fast_compile")
@node_rewriter(None)
......@@ -3325,240 +1265,32 @@ def local_useless_topk(fgraph, node):
    return {old_output: new_output}
@register_useless
@register_canonicalize
@node_rewriter([SpecifyShape])
def local_merge_consecutive_specify_shape(fgraph, node):
    """Collapse ``specify_shape(specify_shape(x, s1), s2)`` into one node.

    The merged shape keeps every dimension specified by either call, with the
    outer call (``s2``) taking precedence where both specify a dimension.
    """
    if not isinstance(node.op, SpecifyShape):
        return False

    inner = node.inputs[0]
    if inner.owner is None or not isinstance(inner.owner.op, SpecifyShape):
        return False

    inner_var, *merged_shape = inner.owner.inputs
    # Overlay the outer shape entries on top of the inner ones.
    for dim, outer_sh in enumerate(node.inputs[1:]):
        if not NoneConst.equals(outer_sh):
            merged_shape[dim] = outer_sh

    # TODO: We could make sure that the overlapping shapes of the two
    # `SpecifyShape`s are the same.

    return [specify_shape(inner_var, merged_shape)]
@register_useless
@register_canonicalize
@node_rewriter([Shape])
def local_Shape_of_SpecifyShape(fgraph, node):
    """Replace ``specify_shape(x, s).shape`` with ``s``."""
    if not isinstance(node.op, Shape):
        return False

    shaped = node.inputs[0]
    if not isinstance(getattr(shaped.owner, "op", None), SpecifyShape):
        return False

    base, *raw_dims = shaped.owner.inputs

    # Unspecified dimensions (`NoneConst`) fall back to the symbolic
    # `shape_i` of the underlying variable.
    dims = [
        shape_i(base, idx, fgraph) if NoneConst.equals(dim) else dim
        for idx, dim in enumerate(raw_dims)
    ]

    return [stack(dims).astype(np.int64)]
@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node):
    """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
    if not isinstance(node.op, Shape_i):
        return False

    var = node.inputs[0]
    if not isinstance(var.type, TensorType):
        return False

    # A broadcastable dimension has a statically known length of 1.
    if not var.broadcastable[node.op.i]:
        return
    return [as_tensor_variable(1, dtype=np.int64)]
# NOTE(review): the block below is corrupted diff-extraction residue — two
# unrelated definitions have been interleaved line-by-line:
#   1. `local_Unique_scalar`, a rewrite converting ``unique(x)`` to ``x``
#      when ``x`` is a scalar; and
#   2. a module-deprecation shim (`import_ShapeFeature` + `DEPRECATED_NAMES`
#      + the start of a module-level `__getattr__`, per PEP 562) pointing
#      `ShapeFeature` to `aesara.tensor.rewriting.shape`.
# It is not valid Python as it stands; restore each definition from the
# original revision before editing.
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_scalar(fgraph, node):
    """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
    if not isinstance(node.op, Unique):
        return False

# Deprecation-shim lines begin here (interleaved with the rewrite above).
def import_ShapeFeature():
    from aesara.tensor.rewriting.shape import ShapeFeature
    if node.op.return_index or node.op.return_inverse or node.op.return_counts:
        return False
    return ShapeFeature
    uniqued_var = node.inputs[0]
    if uniqued_var.ndim != 0:
        return False

DEPRECATED_NAMES = {
    "ShapeFeature": (
        "`ShapeFeature` is now located in `aesara.tensor.rewriting.shape`.",
        import_ShapeFeature,
    ),
}
    old_out = node.outputs[0]
    res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype)
    return [res]

def __getattr__(name):
    """Intercept module-level attribute access of deprecated symbols."""
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    Adapted from https://stackoverflow.com/a/55139609/3006474.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    # Only the plain value-only, axis-less form of `unique` is handled.
    if op.axis is not None:
        return False
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    candidate = node.inputs[0]
    if candidate.owner is None or not isinstance(candidate.owner.op, Alloc):
        return False

    # Only the replicated value matters for the set of unique entries; the
    # `Alloc` shape arguments are dropped.
    source = candidate.owner.inputs[0]
    lifted = op.make_node(source).outputs[0]

    template = node.outputs[0]
    return [as_tensor_variable(lifted, ndim=template.ndim, dtype=template.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    # Only the plain value-only, axis-less form of `unique` is handled.
    if op.axis is not None:
        return False
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    candidate = node.inputs[0]
    if candidate.owner is None or not isinstance(candidate.owner.op, BroadcastTo):
        return False

    # Broadcasting only replicates values, so the unique set of the source is
    # the same; the broadcast shape arguments are dropped.
    source = candidate.owner.inputs[0]
    lifted = op.make_node(source).outputs[0]

    template = node.outputs[0]
    return [as_tensor_variable(lifted, ndim=template.ndim, dtype=template.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    # Only the plain value-only, axis-less form of `unique` is handled.
    if op.axis is not None:
        return False
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    candidate = node.inputs[0]
    if candidate.owner is None or not isinstance(candidate.owner.op, Repeat):
        return False

    # Repeating only replicates values, so the unique set of the source is
    # the same; the repeat arguments are dropped.
    source = candidate.owner.inputs[0]
    lifted = op.make_node(source).outputs[0]

    template = node.outputs[0]
    return [as_tensor_variable(lifted, ndim=template.ndim, dtype=template.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(x, y), axis=None)`` to ``unique(y, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    # Only the plain value-only, axis-less form of `unique` is handled.
    if op.axis is not None:
        return False
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    candidate = node.inputs[0]
    owner = candidate.owner
    if owner is None:
        return False
    if not isinstance(owner.op, Elemwise):
        return False
    if not isinstance(owner.op.scalar_op, aes.Second):
        return False

    # `second(shape_var, y)` broadcasts `y` to `shape_var`'s shape; only `y`
    # matters for the unique values.
    _shape_source, value_source = owner.inputs
    lifted = op.make_node(value_source).outputs[0]

    template = node.outputs[0]
    return [as_tensor_variable(lifted, ndim=template.ndim, dtype=template.dtype)]
# NOTE(review): corrupted diff-extraction residue — the body of the module
# deprecation hook `__getattr__` (which warns and resolves names from
# `DEPRECATED_NAMES`, or raises `AttributeError`) has been interleaved with
# `local_remove_scalar_BroadcastTo` (which strips a zero-dimensional
# `BroadcastTo`).  Not valid Python as it stands; restore both from the
# original revision before editing.
@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
    from warnings import warn

    bcast_shape = node.inputs[1:]

    res = DEPRECATED_NAMES.get(name)
    if res:
        msg, fn = res
        warn(msg, DeprecationWarning, stacklevel=2)
        return fn()

    if not bcast_shape:
        bcasted_var = node.inputs[0]
        # If this isn't true, the graph is invalid
        assert bcasted_var.ndim == 0
        return [bcasted_var]

    raise AttributeError(f"module {__name__} has no attribute {name}")
import sys
import time
from collections import defaultdict
from typing import Optional
from warnings import warn
import aesara
import aesara.scalar.basic as aes
from aesara import compile
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, io_toposort
from aesara.graph.features import ReplaceValidate
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter
from aesara.graph.rewriting.db import SequenceDB
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
from aesara.tensor.shape import shape_padleft
from aesara.tensor.var import TensorConstant
class InplaceElemwiseOptimizer(GraphRewriter):
    r"""
    This is parameterized so that it works for `Elemwise` `Op`\s.
    """

    def __init__(self, OP):
        # The `Elemwise` (sub)class whose nodes this rewriter will try to
        # convert to in-place variants.
        self.op = OP

    def add_requirements(self, fgraph):
        # The `DestroyHandler` feature is what tracks/validates in-place
        # (destructive) operations on the graph.
        from aesara.graph.destroyhandler import DestroyHandler

        fgraph.attach_feature(DestroyHandler())

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        # Pretty-print the `prof` dict returned by `apply`.
        blanc = " " * level
        print(blanc, cls.__name__, prof["opt"].op, file=stream)
        for k in [
            "node_before",
            "nb_call_replace",
            "nb_call_validate",
            "nb_inconsistent",
        ]:
            print(blanc, k, prof[k], file=stream)
        ndim = prof["ndim"]
        if ndim:
            print(blanc, "ndim", "nb", file=stream)
            for n in sorted(ndim.keys()):
                print(blanc, n, ndim[n], file=stream)

    def apply(self, fgraph):
        r"""
        Attempts to replace all `Elemwise`\s by versions of them that operate
        inplace. It operates greedily: for each `Elemwise` that is encountered,
        for each output, it tries each input to see if it can operate inplace
        on that input. If so, it makes the change and goes to the next output
        or `Elemwise`.

        Examples
        --------

            x + y + z -> x += y += z
            (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

        """
        # We should not validate too often as this takes too much time to
        # execute!
        # It is the _dfs_toposort() fct in aesara/graph/destroyhandler.py
        # that takes so much time.
        # Should we try to use another lib that does toposort?
        #   igraph: http://igraph.sourceforge.net/
        #   networkx: https://networkx.lanl.gov/
        # Should we try to use cython?
        #   Compiling only that fct is not enough, should we try to add the
        #   deque class too?
        #   And init the deque and other list to an upper bound number of
        #   elements?
        # Maybe Aesara should do online toposort as in
        #   http://code.google.com/p/acyclic
        #
        # The next longest rewriter is the canonizer phase.
        # Then I think it is the [io_?]toposort (need to validate) so check if
        # the solution is also applicable there.

        # We execute `validate` after this number of change.
        prof = {
            "opt": self,
            "node_before": len(fgraph.apply_nodes),
            "nb_call_replace": 0,
            "nb_call_validate": 0,
            "nb_inconsistent": 0,
            "ndim": defaultdict(lambda: 0),
        }

        check_each_change = config.tensor__insert_inplace_optimizer_validate_nb
        if check_each_change == -1:
            # Auto mode: validate less often on large graphs since validation
            # is expensive.
            if len(fgraph.apply_nodes) > 500:
                check_each_change = 10
            else:
                check_each_change = 1

        nb_change_no_validate = 0
        chk = fgraph.checkpoint()

        if fgraph.update_mapping:
            update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
        else:
            update_outs = []

        protected_inputs = [
            f.protected
            for f in fgraph._features
            if isinstance(f, aesara.compile.function.types.Supervisor)
        ]
        protected_inputs = sum(protected_inputs, [])  # flatten the list
        protected_inputs.extend(fgraph.outputs)
        for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
            op = node.op
            if not isinstance(op, self.op):
                continue
            # If big graph and the outputs are scalar, do not make it
            # inplace.
            if (
                check_each_change != 1
                and
                # If multiple outputs, they must all have the same size,
                # so only check the first.
                getattr(node.outputs[0].type, "ndim", -1) == 0
            ):
                continue

            if op.inplace_pattern:
                # Maybe this isn't needed anymore, but I don't want to
                # rish regression now. This case only happen if the
                # original node add already some inplace patter and we
                # still try to add more pattern.

                baseline = op.inplace_pattern
                candidate_outputs = [
                    i for i in range(len(node.outputs)) if i not in baseline
                ]
                # node inputs that are Constant, already destroyed,
                # or fgraph protected inputs and fgraph outputs can't be used as
                # inplace target.
                # Remove here as faster.
                candidate_inputs = [
                    i
                    for i in range(len(node.inputs))
                    if i not in baseline.values()
                    and not isinstance(node.inputs[i], Constant)
                    and
                    # the next line should not be costly most of the time.
                    not fgraph.has_destroyers([node.inputs[i]])
                    and node.inputs[i] not in protected_inputs
                ]
            else:
                baseline = []
                candidate_outputs = list(range(len(node.outputs)))
                # node inputs that are Constant, already destroyed,
                # fgraph protected inputs and fgraph outputs can't be used as inplace
                # target.
                # Remove here as faster.
                candidate_inputs = [
                    i
                    for i in range(len(node.inputs))
                    if not isinstance(node.inputs[i], Constant)
                    and not fgraph.has_destroyers([node.inputs[i]])
                    and node.inputs[i] not in protected_inputs
                ]

            verbose = False

            raised_warning = not verbose

            for candidate_output in candidate_outputs:
                # If the output of the node can be established as an update
                # output of the fgraph, visit the candidate_inputs in an order
                # that will improve the chances of making the node operate
                # inplace on the input it's meant to update
                candidate_out_var = node.outputs[candidate_output]
                sorted_candidate_inputs = candidate_inputs

                if candidate_out_var in update_outs:
                    # The candidate output is an update. Sort the
                    # variables in candidate_inputs in the following order:
                    # - Vars corresponding to the actual updated input
                    #   (best case scenario is for the node that procudes
                    #   an update to operate inplace on the variable to
                    #   update)
                    # - Vars computed inplace on the updates input (second
                    #   best scenario if for the node to work inplace on
                    #   a variable obtained by a chain of inplace on the
                    #   variable to update. In some cases, this will be
                    #   equivalent to operating inplace on the variable to
                    #   update)
                    # - Remaining variables
                    updated_inputs = []
                    for i, f_out in enumerate(fgraph.outputs):
                        if f_out is candidate_out_var and i in fgraph.update_mapping:
                            updated_inp_idx = fgraph.update_mapping[i]
                            updated_inputs.append(fgraph.inputs[updated_inp_idx])

                    updated_vars = []
                    vars_from_inplace = []
                    other_vars = []
                    for inp_idx in candidate_inputs:
                        inp = node.inputs[inp_idx]
                        if inp in updated_inputs:
                            # the candidate input is the actual updated input
                            updated_vars.append(inp_idx)
                        elif (
                            hasattr(fgraph, "destroy_handler")
                            and inp.owner
                            and any(
                                fgraph.destroy_handler.root_destroyer.get(up_inp, None)
                                is inp.owner
                                for up_inp in updated_inputs
                            )
                        ):
                            # the candidate input is a variable computed
                            # inplace on the updated input via a sequence of
                            # one or more inplace operations
                            vars_from_inplace.append(inp_idx)
                        else:
                            other_vars.append(inp_idx)

                    sorted_candidate_inputs = (
                        updated_vars + vars_from_inplace + other_vars
                    )

                for candidate_input in sorted_candidate_inputs:
                    # remove inputs that don't have the same dtype as the output
                    if (
                        node.inputs[candidate_input].type
                        != node.outputs[candidate_output].type
                    ):
                        continue

                    inplace_pattern = dict(baseline)
                    inplace_pattern[candidate_output] = candidate_input
                    try:
                        # Build the in-place scalar op, preferring the op's own
                        # `make_new_inplace` hook when it provides one.
                        if hasattr(op.scalar_op, "make_new_inplace"):
                            new_scal = op.scalar_op.make_new_inplace(
                                aes.transfer_type(
                                    *[
                                        inplace_pattern.get(i, o.dtype)
                                        for i, o in enumerate(node.outputs)
                                    ]
                                )
                            )
                        else:
                            new_scal = op.scalar_op.__class__(
                                aes.transfer_type(
                                    *[
                                        inplace_pattern.get(i, None)
                                        for i in range(len(node.outputs))
                                    ]
                                )
                            )
                        new_outputs = self.op(new_scal, inplace_pattern)(
                            *node.inputs, return_list=True
                        )
                        new_node = new_outputs[0].owner

                        for r, new_r in zip(node.outputs, new_outputs):
                            prof["nb_call_replace"] += 1
                            fgraph.replace(
                                r, new_r, reason="inplace_elemwise_optimizer"
                            )
                        nb_change_no_validate += 1
                        prof["ndim"][candidate_out_var.ndim] += 1
                        if nb_change_no_validate >= check_each_change:
                            prof["nb_call_validate"] += 1
                            fgraph.validate()
                            chk = fgraph.checkpoint()
                            nb_change_no_validate = 0
                    except (ValueError, InconsistencyError) as e:
                        prof["nb_inconsistent"] += 1
                        if check_each_change != 1 and not raised_warning:
                            print(
                                (
                                    "Some inplace rewriting was not "
                                    "performed due to an unexpected error:"
                                ),
                                file=sys.stderr,
                            )
                            print(e, file=sys.stderr)
                            raised_warning = True
                        # Roll back to the last validated checkpoint.
                        fgraph.revert(chk)
                        continue
                    # Success: continue greedily from the rewritten node.
                    candidate_inputs.remove(candidate_input)
                    node = new_node
                    baseline = inplace_pattern
                    break

        if nb_change_no_validate > 0:
            # Validate any trailing unvalidated changes; revert them all if
            # validation fails.
            try:
                fgraph.validate()
            except Exception:
                if not raised_warning:
                    print(
                        (
                            "Some inplace rewriting was not "
                            "performed due to an unexpected error"
                        ),
                        file=sys.stderr,
                    )
                fgraph.revert(chk)
        return prof

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        print(
            f"{' ' * level}{self.__class__.__name__} ({self.op})",
            file=stream,
        )
# NOTE(review): the `return` below sits at module level — it appears to be
# diff residue from a removed factory function (the old
# `inplace_elemwise_optimizer_op`) and is not valid Python here; remove it or
# restore the enclosing function from the original revision.
return inplace_elemwise_optimizer

# Instantiate the in-place rewriter for `Elemwise` and register it late in the
# rewrite pipeline.
inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register(  # type: ignore
    "inplace_elemwise_opt",
    inplace_elemwise_optimizer,
    "inplace_opt",  # for historic reason
    "inplace_elemwise_optimizer",
    "fast_run",
    "inplace",
    position=75,
)
def apply_local_dimshuffle_lift(fgraph, var):
    """Apply `local_dimshuffle_lift` to ``var``'s owner, if any.

    Returns the (possibly replaced) variable; used to lift recursively.
    """
    if var.owner is None:
        return var
    replacement = local_dimshuffle_lift.transform(fgraph, var.owner)
    return replacement[0] if replacement else var
def is_dimshuffle_useless(new_order, input):
    """Return ``True`` when applying ``new_order`` to ``input`` is a no-op.

    A dimshuffle is useless when it either keeps every dimension in place, or
    only moves/inserts broadcastable dimensions (which all have length 1 and
    are therefore interchangeable with each other and with ``"x"``).
    """
    ndim = input.type.ndim
    # A shuffle that changes the rank is never a no-op.
    if len(new_order) != ndim:
        return False

    # Axes that are broadcastable, plus the "x" marker for inserted axes.
    bcast_axes = {
        axis for axis, bcast in enumerate(input.type.broadcastable) if bcast
    }
    bcast_axes.add("x")

    return all(
        new_order[axis] == axis
        or (axis in bcast_axes and new_order[axis] in bcast_axes)
        for axis in range(ndim)
    )
@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
    """
    "Lifts" DimShuffle through Elemwise operations and merges
    consecutive DimShuffles. Basically, applies the following
    transformations on the whole graph:

    DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
    DimShuffle(DimShuffle(x)) => DimShuffle(x)
    DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)

    After this transform, clusters of Elemwise operations are
    void of DimShuffle operations.

    """
    op = node.op
    if not isinstance(op, DimShuffle):
        return False

    inp = node.inputs[0]
    inode = inp.owner
    new_order = op.new_order
    # Case 1: lift the `DimShuffle` above an `Elemwise` with a single client,
    # pushing a copy of the shuffle onto each of the `Elemwise`'s inputs.
    if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1):
        # Don't use make_node to have tag.test_value set.
        new_inputs = []
        for inp in inode.inputs:
            new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
            # Recursively lift the newly created shuffles too.
            new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))

        copy_stack_trace(node.outputs[0], new_inputs)

        ret = inode.op(*new_inputs, return_list=True)
        return ret

    # Case 2: merge with an inner `DimShuffle` by composing the two orders
    # ("x" entries stay "x"; other entries are remapped through the inner
    # shuffle's order).
    if inode and isinstance(inode.op, DimShuffle):
        new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order]
        inp = inode.inputs[0]

    # Drop the (possibly composed) shuffle entirely when it is a no-op;
    # otherwise emit the single merged `DimShuffle`.
    if is_dimshuffle_useless(new_order, inp):
        return [inp]
    elif inode and isinstance(inode.op, DimShuffle):
        ret = op.__class__(inp.type.broadcastable, new_order)(inp)
        ret = apply_local_dimshuffle_lift(fgraph, ret)
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
def local_useless_dimshuffle_makevector(fgraph, node):
    r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`\s.

    This cleans up after `local_subtensor_remove_broadcastable_index`, which
    produces a not-so-intuitive canonical form for ``x[0]`` when
    ``x.shape == (1,)`` (i.e. one broadcastable dimension):
    ``x.dimshuffle(())``.
    """
    # Only handle the shuffle that drops the single broadcastable dimension.
    if node.op.new_order != ():
        return

    mv_out = node.inputs[0]
    if mv_out.owner is None:
        return
    if not isinstance(mv_out.owner.op, MakeVector):
        return
    if mv_out.broadcastable != (True,):
        return

    # A (True,)-broadcastable `MakeVector` necessarily has a single input.
    assert len(mv_out.owner.inputs) == 1
    return [mv_out.owner.inputs[0]]
@register_canonicalize
@node_rewriter([Elemwise])
def local_upcast_elemwise_constant_inputs(fgraph, node):
    """This explicitly upcasts constant inputs to elemwise Ops, when
    those Ops do implicit upcasting anyway.

    Rationale: it helps merge things like (1-x) and (1.0 - x).
    """
    if len(node.outputs) > 1:
        return
    try:
        # `shape_i` comes from the fgraph's shape feature (when attached) and
        # builds a symbolic graph for one dimension of a variable's shape.
        shape_i = fgraph.shape_feature.shape_i
    except AttributeError:
        shape_i = None
    if isinstance(node.op, Elemwise):
        scalar_op = node.op.scalar_op
        # print "aa", scalar_op.output_types_preference
        # Only ops that already upcast implicitly are handled.
        if getattr(scalar_op, "output_types_preference", None) in (
            aes.upgrade_to_float,
            aes.upcast_out,
        ):
            # this is the kind of op that we can screw with the input
            # dtypes by upcasting explicitly
            output_dtype = node.outputs[0].type.dtype
            new_inputs = []
            for i in node.inputs:
                if i.type.dtype == output_dtype:
                    new_inputs.append(i)
                else:
                    try:
                        # works only for scalars
                        cval_i = get_scalar_constant_value(
                            i, only_process_constants=True
                        )
                        if all(i.broadcastable):
                            # Fully broadcastable: rebuild the constant at the
                            # output dtype with the same number of dimensions.
                            new_inputs.append(
                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
                            )
                        else:
                            if shape_i is None:
                                # Without shape information we cannot rebuild
                                # a non-broadcastable constant; give up.
                                return
                            new_inputs.append(
                                alloc(
                                    cast(cval_i, output_dtype),
                                    *[shape_i(d)(i) for d in range(i.ndim)],
                                )
                            )
                            # print >> sys.stderr, "AAA",
                            # *[Shape_i(d)(i) for d in range(i.ndim)]
                    except NotScalarConstantError:
                        # for the case of a non-scalar
                        if isinstance(i, TensorConstant):
                            new_inputs.append(cast(i, output_dtype))
                        else:
                            new_inputs.append(i)
            if new_inputs != node.inputs:
                rval = [node.op(*new_inputs)]
                if not node.outputs[0].type.is_super(rval[0].type):
                    # This can happen for example when floatX=float32
                    # and we do the true division between and int64
                    # and a constant that will get typed as int8.
                    # As this is just to allow merging more case, if
                    # the upcast don't work, we can just skip it.
                    return

                # Copy over output stacktrace from before upcasting
                copy_stack_trace(node.outputs[0], rval)
                return rval
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
    r"""Create a recursive function that fuses `Elemwise` `Op`\s.

    The basic idea is that we loop through an `Elemwise` node's inputs, find
    other `Elemwise` nodes, determine the scalars input types for all of the
    `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
    and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
    new "fused" `Elemwise`.

    It's parameterized in order to work for `Elemwise` `Op`\s.

    Parameters
    ----------
    op_class : type
        `Elemwise` class (the one that we want to fuse)
    max_input_fct : callable
        A function that returns the maximum number of inputs that this `Elemwise`
        can take.
        On the CPU we limit to 32 input variables since that is the maximum
        NumPy support.
    maker: callable
        A function with the signature ``(node, *args)`` that constructs an
        `op_class` instance (e.g. ``op_class(*args)``).

    """
    if maker is None:

        def maker(node, scalar_op):
            return op_class(scalar_op)

    def local_fuse(fgraph, node):
        r"""Fuse `Elemwise` `Op`\s in a node.

        As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
        same shape.

        For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
        compiler do the cast.

        The number of dimensions is validated at call time by Aesara itself.

        """
        # TODO: use broadcast flag?

        # TODO: don't do this rewrite as a `NodeRewriter`.
        # Analyze the graph in terms of elemwise subgraphs, and then
        # replace each subgraph with a Composite version.

        # TODO: use malloc and copy to transfer arguments that don't
        # fit within the parameter space of 256 bytes
        #
        # TODO: Merge with multiple output to merge when an inputs
        # have multiple clients. This can't be done with a `NodeRewriter`

        # TODO: Related: Support composites with multiple outputs

        # TODO: Use Composite to combine Elemwise and Reduce
        # operations.  We have to loop over the data anyway... might
        # as well sum it up while we're at it (this can be trickier
        # than i'm making it seound here. The data-traversal should be
        # done contiguously, and the summing-up might not be easy or
        # worthwhile if the summation axis doesn't line up with a
        # contiguous dimension)

        if type(node.op) is not op_class:
            return False

        if len(node.outputs) > 1:
            # We don't support fusion for nodes with multiple outputs.
            return

        inputs = []  # inputs of the new Elemwise op.
        s_inputs = []  # inputs of the new scalar op used by the Composite.
        # Inputs of the new scalar op that represents the current node.
        s_g = []

        # There is a hard limit of 256 bytes for the formal argument list to a
        # GPU kernel function.
        max_nb_input = max_input_fct(node)
        # The number of inputs to the new fused op if we do not fuse more
        # inputs.
        new_nb_input = len(node.inputs)
        # Did we fuse something?
        # Needed as we can fuse unary op that don't change the number of
        # inputs.
        # And there is a case where the inputs are the same as the current
        # node. That won't change the number of inputs of the new op.
        fused = False

        for i in node.inputs:
            scalar_node: Optional[Apply] = None
            # Will store inputs of the fused node that are not currently inputs
            # of the node we want to create (to avoid duplicating inputs).
            tmp_input = []
            # Same as tmp_input, but for scalars.
            tmp_scalar = []

            # We should not check the number of inputs here
            # As fusing op don't always change the number of input.
            # If a variable is used as multiple into to the same node,
            # we still want to fusion. So we take the set.
            if (
                i.owner
                and isinstance(i.owner.op, op_class)
                and len({n for n, idx in fgraph.clients[i]}) == 1
                and
                # Do not merge elemwise that don't have the same
                # broadcastable pattern to don't redo duplicate
                # computation due to broadcast.
                i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable
            ):
                try:
                    tmp_s_input = []
                    # we should not put duplicate input into s_inputs and inputs
                    for ii in i.owner.inputs:
                        if ii in inputs:
                            tmp_s_input.append(s_inputs[inputs.index(ii)])
                        elif ii in tmp_input:
                            tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
                        else:
                            tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
                            try:
                                tv = get_test_value(ii)
                                # Sometimes the original inputs have
                                # zero-valued shapes in some dimensions, which
                                # implies that this whole scalar thing doesn't
                                # make sense (i.e. we're asking for the scalar
                                # value of an entry in a zero-dimensional
                                # array).
                                # This will eventually lead to an error in the
                                # `compute_test_value` call below when/if
                                # `config.compute_test_value_opt` is enabled
                                # (for debugging, more or less)
                                tmp.tag.test_value = tv.item()
                            except (TestValueError, ValueError):
                                pass

                            tmp_s_input.append(tmp)
                            tmp_input.append(ii)
                            tmp_scalar.append(tmp_s_input[-1])

                    # Use the `Op.make_node` interface in case `Op.__call__`
                    # has been customized
                    scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)

                    if config.compute_test_value_opt != "off":
                        # This is required because `Op.make_node` won't do it
                        compute_test_value(scalar_node)

                    # If the scalar_op doesn't have a C implementation, we skip
                    # its fusion to allow fusion of the other ops
                    i.owner.op.scalar_op.c_code(
                        scalar_node,
                        "test_presence_of_c_code",
                        ["x" for x in i.owner.inputs],
                        ["z" for z in i.owner.outputs],
                        {"fail": "%(fail)s"},
                    )
                except (NotImplementedError, MethodNotDefined):
                    warn(
                        (
                            "Rewrite warning: "
                            f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
                            " As well as being potentially slow, this also disables "
                            "loop fusion."
                        )
                    )
                    scalar_node = None

            # Compute the number of inputs in case we fuse this input.
            # We subtract 1 because we replace the existing input with the new
            # inputs from `tmp_input`.
            new_nb_input_ = new_nb_input + len(tmp_input) - 1

            # If the new input is already an input of the current node, it was
            # already counted when `new_nb_input` was initialized to
            # len(node.inputs).
            # This can happen when a variable is used both by the Elemwise to
            # fuse and the current node.
            for x in tmp_input:
                if x in node.inputs:
                    new_nb_input_ -= 1

            if scalar_node and (new_nb_input_ <= max_nb_input):
                fused = True
                new_nb_input = new_nb_input_
                inputs.extend(tmp_input)
                s_inputs.extend(tmp_scalar)
                s_g.extend(scalar_node.outputs)
            else:
                # We must support the case where the same variable appears many
                # times within the inputs
                if inputs.count(i) == node.inputs.count(i):
                    s = s_inputs[inputs.index(i)]
                else:
                    s = aes.get_scalar_type(i.type.dtype).make_variable()
                    if config.compute_test_value_opt != "off":
                        try:
                            v = get_test_value(i)
                            # See the zero-dimensional test value situation
                            # described above.
                            s.tag.test_value = v.item()
                        except (TestValueError, ValueError):
                            pass

                    inputs.append(i)
                    s_inputs.append(s)
                s_g.append(s)

        if not fused:
            return False

        if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
            # TODO FIXME: This shouldn't be a generic `Exception`
            raise Exception(
                "Something has gone wrong with the elemwise fusion rewrite; skipping."
            )

        s_new_out = node.op.scalar_op(*s_g, return_list=True)
        try:
            # Skip fusion entirely when the resulting composite's root op has
            # no C implementation.
            s_new_out[0].owner.op.c_code(
                s_new_out[0].owner,
                "test_presence_of_c_code",
                ["x" for x in s_g],
                ["z" for x in s_new_out],
                {"fail": "%(fail)s"},
            )
        except (NotImplementedError, MethodNotDefined):
            name = str(s_new_out[0].owner.op)
            warn(
                (
                    "Rewrite warning: "
                    f"The Op {name} does not provide a C implementation."
                    " As well as being potentially slow, this also disables "
                    "loop fusion."
                )
            )
            return False

        # create the composite op.
        composite_op = aes.Composite(s_inputs, s_new_out)

        # create the new node.
        # Do not call make_node to have test_value
        new_node = maker(node, composite_op)(*inputs).owner

        assert len(new_node.outputs) == 1
        assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype

        if len(new_node.inputs) > max_nb_input:
            warn(
                "Loop fusion failed because the resulting node "
                "would exceed the kernel argument limit."
            )
            return False

        # we fuse as many that we can at the same time to make debug mode faster
        # debug mode will be faster as it won't test all intermediate step.
        while True:
            ret = local_fuse(fgraph, new_node)
            if ret is not False and ret is not None:
                assert len(ret) == len(new_node.outputs)
                assert len(ret) == 1
                new_node = ret[0].owner
            else:
                break

        return new_node.outputs

    return local_fuse
def elemwise_max_input_fct(node):
    """Return the maximum number of inputs allowed for a fused `Elemwise` node."""
    # Without a C compiler, `Elemwise.perform` falls back on NumPy ufuncs,
    # and those are limited to 31 arguments.
    if config.cxx:
        return 1024
    return 31
# The default `Elemwise` fusion rewriter; per-node input count is capped by
# `elemwise_max_input_fct`.
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(GraphRewriter):
    """Graph rewriter that simply runs node fusion operations.

    The given ``node_rewriter`` is applied to every node, in reverse
    topological order, and whole passes are repeated until one pass makes
    no replacement.

    TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.
    """

    def __init__(self, node_rewriter):
        super().__init__()
        # The fusion node rewriter (e.g. `local_elemwise_fusion`) run by `apply`.
        self.node_rewriter = node_rewriter

    def add_requirements(self, fgraph):
        # `apply` performs replacements through `replace_all_validate`.
        fgraph.attach_feature(ReplaceValidate())

    def apply(self, fgraph):
        """Run the node rewriter to a fixed point.

        Returns a profiling tuple:
        ``(self, nb_iter, nb_replacement, nb_inconsistency_replace,
        validate_time, callback_time, callbacks_time, time_toposort)``.
        """
        did_something = True
        nb_iter = 0
        nb_replacement = 0
        nb_inconsistency_replace = 0
        time_toposort = 0
        if fgraph.profile:
            # Snapshot profiling counters so we can report only our own cost.
            validate_before = fgraph.profile.validate_time
            callbacks_before = fgraph.execute_callbacks_times.copy()
            callback_before = fgraph.execute_callbacks_time
        while did_something:
            t0 = time.time()
            nodelist = list(fgraph.toposort())
            time_toposort += time.time() - t0
            nodelist.reverse()
            did_something = False
            for node in nodelist:
                # Don't try to fuse node that have already been fused.
                if node in fgraph.apply_nodes:
                    new_outputs = self.node_rewriter(fgraph, node)
                    if new_outputs:
                        assert len(new_outputs) == len(node.outputs)
                        try:
                            fgraph.replace_all_validate(
                                list(zip(node.outputs, new_outputs)),
                                reason=self.__class__.__name__,
                            )
                            did_something = True
                            nb_replacement += 1
                        except InconsistencyError:
                            nb_inconsistency_replace += 1
            nb_iter += 1

        if fgraph.profile:
            validate_time = fgraph.profile.validate_time - validate_before
            callback_time = fgraph.execute_callbacks_time - callback_before
            callbacks_time = {}
            for k, v in fgraph.execute_callbacks_times.items():
                if k in callbacks_before:
                    callbacks_time[k] = v - callbacks_before[k]
                else:
                    callbacks_time[k] = v
        else:
            validate_time = None
            callback_time = None
            callbacks_time = {}
        return (
            self,
            nb_iter,
            nb_replacement,
            nb_inconsistency_replace,
            validate_time,
            callback_time,
            callbacks_time,
            time_toposort,
        )

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        """Pretty-print the profiling tuple returned by `apply` to `stream`."""
        blanc = " " * level
        print(blanc, cls.__name__, file=stream)
        print(blanc, " nb_iter", prof[1], file=stream)
        print(blanc, " nb_replacement", prof[2], file=stream)
        print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
        print(blanc, " validate_time", prof[4], file=stream)
        print(blanc, " callback_time", prof[5], file=stream)
        if prof[5] is not None and prof[5] > 1:
            print(blanc, " callbacks_time", file=stream)
            for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]:
                if i[1] > 0:
                    # Bug fix: this line previously printed to stdout instead of
                    # `stream`, unlike every other line of this report.
                    print(blanc, " ", i, file=stream)
        print(blanc, " time_toposort", prof[7], file=stream)
# Register the fusion pass in the global rewrite database. When the config
# flag is enabled, fusion is wrapped in its own `SequenceDB` and tagged
# "fast_run"; otherwise it is registered without that tag (so it only runs
# when explicitly requested via the "fusion" tag).
if config.tensor__local_elemwise_fusion:
    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
    fuse_seqopt = SequenceDB()
    fuse_seqopt.register(
        "composite_elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fast_run",
        "fusion",
        position=1,
    )
    compile.optdb.register(  # type: ignore
        "elemwise_fusion",
        fuse_seqopt,
        "fast_run",
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
else:
    # Same rewriter, but not part of "fast_run".
    compile.optdb.register(  # type: ignore
        "elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
@register_canonicalize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
    r"""Drop the unused outputs of a multi-output elemwise `Composite`."""
    op = node.op
    if not (isinstance(op, Elemwise) and isinstance(op.scalar_op, aes.Composite)):
        return
    comp = op.scalar_op
    # Indices of the outputs that actually have clients in the graph.
    used_idx = [i for i, out in enumerate(node.outputs) if fgraph.clients[out]]
    if len(used_idx) == len(node.outputs):
        # Nothing to trim.
        return
    trimmed = aes.Composite(
        inputs=comp.inputs, outputs=[comp.outputs[i] for i in used_idx]
    )
    new_outs = Elemwise(scalar_op=trimmed)(*node.inputs, return_list=True)
    return dict(zip([node.outputs[i] for i in used_idx], new_outs))
import aesara.scalar.basic as aes
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.basic import Alloc, as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique
from aesara.tensor.rewriting.basic import register_canonicalize, register_useless
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_scalar(fgraph, node):
    """Replace ``unique(x)`` with ``x`` when ``x`` is a scalar."""
    op = node.op
    if not isinstance(op, Unique):
        return False
    # Extra outputs would change the interface; bail out in that case.
    if op.return_index or op.return_inverse or op.return_counts:
        return False
    x = node.inputs[0]
    if x.ndim != 0:
        return False
    # A 0-d input is trivially unique already.
    old_out = node.outputs[0]
    return [as_tensor_variable(x, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Replace ``unique(alloc(x, ...), axis=None)`` with ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False
    inp = node.inputs[0]
    if inp.owner is None or not isinstance(inp.owner.op, Alloc):
        return False
    # `Alloc` only replicates its value, so the set of unique elements is
    # unchanged when `Unique` is applied to the allocated value directly.
    value = inp.owner.inputs[0]
    new_unique = op.make_node(value).outputs[0]
    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Replace ``unique(broadcast_to(x, ...), axis=None)`` with ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False
    inp = node.inputs[0]
    if inp.owner is None or not isinstance(inp.owner.op, BroadcastTo):
        return False
    # Broadcasting only duplicates existing values, so the unique set of the
    # broadcast result equals that of the original variable.
    value = inp.owner.inputs[0]
    new_unique = op.make_node(value).outputs[0]
    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Replace ``unique(repeat(x, ...), axis=None)`` with ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False
    inp = node.inputs[0]
    if inp.owner is None or not isinstance(inp.owner.op, Repeat):
        return False
    # `Repeat` only duplicates values, so the unique set is unchanged when
    # `Unique` is applied to the repeated value directly.
    value = inp.owner.inputs[0]
    new_unique = op.make_node(value).outputs[0]
    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_second(fgraph, node):
    """Replace ``unique(second(x, y), axis=None)`` with ``unique(y, axis=None)``.

    ``second(x, y)`` broadcasts ``y`` to the shape of ``x``, which only
    duplicates values, so the unique set is that of ``y`` alone.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False
    inp = node.inputs[0]
    if not (
        inp.owner
        and isinstance(inp.owner.op, Elemwise)
        and isinstance(inp.owner.op.scalar_op, aes.Second)
    ):
        return False
    # `Second`'s output is its second input broadcast to the first's shape.
    seconded_var = inp.owner.inputs[1]
    new_unique = op.make_node(seconded_var).outputs[0]
    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
    r"""Remove a `BroadcastTo` whose target shape is empty (a 0-d no-op)."""
    # `BroadcastTo`'s inputs are the value followed by the target shape
    # entries; no shape entries means the result is 0-d.
    bcast_shape = node.inputs[1:]
    if not bcast_shape:
        bcasted_var = node.inputs[0]
        # If this isn't true, the graph is invalid
        assert bcasted_var.ndim == 0
        return [bcasted_var]
......@@ -72,10 +72,8 @@ from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sq
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import true_div
from aesara.tensor.rewriting.basic import (
FusionOptimizer,
broadcast_like,
encompasses_broadcastable,
fuse_seqopt,
local_fill_sink,
register_canonicalize,
register_specialize,
......@@ -84,6 +82,7 @@ from aesara.tensor.rewriting.basic import (
register_uncanonicalize,
register_useless,
)
from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
......
import traceback
from io import StringIO
from typing import Optional
from typing import cast as type_cast
from warnings import warn
import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import (
GraphRewriter,
RemovalNodeRewriter,
check_chain,
copy_stack_trace,
node_rewriter,
)
from aesara.graph.utils import InconsistencyError, get_variable_trace_string
from aesara.tensor.basic import (
MakeVector,
as_tensor_variable,
cast,
constant,
extract_constant,
get_scalar_constant_value,
stack,
tensor_copy,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
register_useless,
topo_constant_folding,
)
from aesara.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
shape_i,
specify_shape,
unbroadcast,
)
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from aesara.tensor.type_other import NoneConst
class ShapeFeature(Feature):
    r"""A `Feature` that tracks shape information in a graph.

    This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
    `Shape_i` and `MakeVector` `Op`\s.

    This `Feature` and its associated rewrites have several goals:

    1. to "lift" `Shape`\s to as close to the inputs as possible,
    2. to infer the shape of every node in the graph in terms of the
       input shapes, and
    3. remove fill `Op`\s (e.g. `Second`) from the graph.

    Lifting shapes as close to the inputs as possible is important for
    canonicalization because it is very bad form to have to compute
    something just to know how big it will be. Firstly, it is a waste
    of time to compute such outputs. But it is important to get rid
    of these outputs as early as possible in the compilation process
    because the extra computations make it appear as if many internal
    graph nodes have multiple clients. Many rewrites refuse to
    work on nodes with multiple clients.

    Lifting is done by using an `<Op>.infer_shape` function if one is
    present, or else using a conservative default. An Op that
    supports shape-lifting should define a infer_shape(self, fgraph, node,
    input_shapes) function. The argument input_shapes is a tuple of
    tuples... there is an interior tuple for each input to the node.
    The tuple has as many elements as dimensions. The element in
    position i of tuple j represents the i'th shape component of the
    j'th input. The function should return a tuple of tuples. One
    output tuple for each node.output. Again, the i'th element of the
    j'th output tuple represents the output[j].shape[i] of the
    function. If an output is not a TensorType, then None should be
    returned instead of a tuple for that output.

    For example the infer_shape for a matrix-matrix product would accept
    input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).

    Inferring the shape of internal nodes in the graph is important
    for doing size-driven rewrites. If we know how big various
    intermediate results will be, we can estimate the cost of many Ops
    accurately, and generate c-code that is specific [e.g. unrolled]
    to particular sizes.

    In cases where you cannot figure out the shape, raise a ShapeError.

    Notes
    -----
    Right now there is only the ConvOp that could really take
    advantage of this shape inference, but it is worth it even
    just for the ConvOp. All that's necessary to do shape
    inference is 1) to mark shared inputs as having a particular
    shape, either via a .tag or some similar hacking; and 2) to
    add an optional In() argument to promise that inputs will
    have a certain shape (or even to have certain shapes in
    certain dimensions).

    We can't automatically infer the shape of shared variables as they can
    change of shape during the execution by default.

    To use this shape information in rewrites, use the
    ``shape_of`` dictionary.

    For example:

    .. code-block:: python

        try:
            shape_of = fgraph.shape_feature.shape_of
        except AttributeError:
            # This can happen when the mode doesn't include the ShapeFeature.
            return

        shape_of_output_zero = shape_of[node.output[0]]

    The ``shape_of_output_zero`` symbol will contain a tuple, whose
    elements are either integers or symbolic integers.

    TODO: check to see if the symbols are necessarily
    non-constant... or are integer literals sometimes Aesara
    constants?? That would be confusing.

    """

    def get_node_infer_shape(self, node):
        """Infer `node`'s output shapes via its `Op.infer_shape`, with fallbacks.

        Falls back on `default_infer_shape` when the `Op` has no
        `infer_shape`, when it raises `ShapeError`, or (depending on
        ``config.on_shape_error``) when it raises any other exception.
        """
        try:
            shape_infer = node.op.infer_shape
        except AttributeError:
            shape_infer = self.default_infer_shape

        try:
            o_shapes = shape_infer(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )
        except ShapeError:
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )
        except NotImplementedError as e:
            raise NotImplementedError(
                "Code called by infer_shape failed raising a "
                "NotImplementedError. Raising NotImplementedError to "
                "indicate that a shape cannot be computed is no longer "
                "supported, and one should now use ShapeError "
                f"instead. The original exception message is: {e}"
            ).with_traceback(e.__traceback__)
        except Exception as e:
            msg = (
                f"Failed to infer_shape from Op {node.op}.\nInput shapes: "
                f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: "
                f"{type(e)}\nException message: {str(e)}\nTraceback: {traceback.format_exc()}"
            )
            if config.on_shape_error == "raise":
                raise Exception(msg).with_traceback(e.__traceback__)
            else:
                warn(msg)
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )

        return o_shapes

    def get_shape(self, var, idx):
        """Rewrites can call this to get a `Shape_i`.

        It is better to call this then use directly ``shape_of[var][idx]``
        as this method should update `shape_of` if needed.

        TODO: Up to now, we don't update it in all cases. Update in all cases.
        """
        r = self.shape_of[var][idx]
        if (
            r.owner
            and isinstance(r.owner.op, Shape_i)
            and r.owner.inputs[0] not in self.fgraph.variables
        ):
            assert var.owner
            node = var.owner
            # recur on inputs
            # NOTE(review): `getattr(i.type, "ndim", None) > 0` raises
            # `TypeError` if `ndim` is missing; presumably all inputs here
            # are tensor-like -- confirm.
            for i in node.inputs:
                if getattr(i.type, "ndim", None) > 0:
                    self.get_shape(i, 0)
            o_shapes = self.get_node_infer_shape(node)
            assert len(o_shapes) == len(node.outputs)

            # Only change the variables and dimensions that would introduce
            # extra computation
            for new_shps, out in zip(o_shapes, node.outputs):
                if not hasattr(out.type, "ndim"):
                    continue

                merged_shps = list(self.shape_of[out])
                changed = False
                for i in range(out.type.ndim):
                    n_r = merged_shps[i]
                    if (
                        n_r.owner
                        and isinstance(n_r.owner.op, Shape_i)
                        and n_r.owner.inputs[0] not in self.fgraph.variables
                    ):
                        changed = True
                        merged_shps[i] = new_shps[i]

                if changed:
                    self.set_shape(out, merged_shps, override=True)
            r = self.shape_of[var][idx]
        return r

    def shape_ir(self, i, r):
        """Return symbolic r.shape[i] for tensor variable r, int i."""
        if hasattr(r.type, "shape") and r.type.shape[i] is not None:
            # The dimension is statically known; emit it as a constant.
            return constant(r.type.shape[i], dtype="int64")
        else:
            # Do not call make_node for test_value
            s = Shape_i(i)(r)
            try:
                s = get_scalar_constant_value(s)
            except NotScalarConstantError:
                pass
            return s

    def shape_tuple(self, r):
        """Return a tuple of symbolic shape vars for tensor variable r."""
        if not hasattr(r.type, "ndim"):
            # This happen for NoneConst.
            return None
        return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))

    def default_infer_shape(self, fgraph, node, i_shapes):
        """Return a list of shape tuple or None for the outputs of node.

        This function is used for Ops that don't implement infer_shape.
        Ops that do implement infer_shape should use the i_shapes parameter,
        but this default implementation ignores it.
        """
        rval = []
        for r in node.outputs:
            try:
                rval.append(self.shape_tuple(r))
            except AttributeError:
                rval.append(None)
        return rval

    def unpack(self, s_i, var):
        """Return a symbolic integer scalar for the shape element s_i.

        The s_i argument was produced by the infer_shape() of an Op subclass.

        var: the variable that correspond to s_i. This is just for
        error reporting.
        """
        assert s_i is not None

        if s_i == 1:
            # Reuse the shared constant 1 so comparisons can use `equals`.
            return self.lscalar_one
        if isinstance(s_i, float) and int(s_i) == s_i:
            s_i = int(s_i)
        if isinstance(s_i, (np.integer, int)) or (
            isinstance(s_i, np.ndarray) and s_i.ndim == 0
        ):
            # this shape is a constant
            if s_i < 0:
                msg = "There is a negative shape in the graph!"
                msg += get_variable_trace_string(var)
                # The rest of the pipeline don't handle correctly this
                # case. So we have 2 choices, stop compilation or
                # consider the shape as unknown. As we have more
                # chance to give the stack trace here then later, I
                # choose that options as it would give better error
                # message.
                raise AssertionError(msg)
            return constant(s_i, dtype="int64")
        if isinstance(s_i, (tuple, list)):
            # this dimension is the same as many of the inputs
            # which tells us that if one of the inputs is known,
            # the others all become known.
            # TODO: should be implemented in Elemwise, and Dot
            #
            # worst case, we loop over shape_of and replace things
            raise NotImplementedError(s_i)

        # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
        if (
            s_i.owner
            and isinstance(s_i.owner.op, Subtensor)
            and s_i.owner.inputs[0].owner
            and isinstance(s_i.owner.inputs[0].owner.op, Shape)
        ):
            assert s_i.type.ndim == 0
            assert len(s_i.owner.op.idx_list) == 1

            # The current Subtensor always put constant index in the graph.
            # This was not True in the past. So call the Subtensor function
            # that will return the right index.
            idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
            assert len(idx) == 1
            idx = idx[0]
            try:
                i = get_scalar_constant_value(idx)
            except NotScalarConstantError:
                pass
            else:
                # Executed only if no exception was raised
                x = s_i.owner.inputs[0].owner.inputs[0]
                # x should already have been imported, and should be in shape_of.
                s_i = self.shape_of[x][i]

        if s_i.type.dtype in integer_dtypes:
            if getattr(s_i.type, "ndim", 0):
                raise TypeError("Shape element must be scalar", s_i)
            return s_i
        else:
            raise TypeError(
                "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None)
            )

    def set_shape(self, r, s, override=False):
        """Assign the shape `s` to previously un-shaped variable `r`.

        Parameters
        ----------
        r : a variable
        s : None or a tuple of symbolic integers
        override : If False, it mean r is a new object in the fgraph.
            If True, it mean r is already in the fgraph and we want to
            override its shape.
        """
        if not override:
            assert r not in self.shape_of, "r already in shape_of"
        if s is None:
            self.shape_of[r] = s
        else:
            if not isinstance(s, (tuple, list)):
                raise TypeError("shapes must be tuple/list", (r, s))

            if r.type.ndim != len(s):
                sio = StringIO()
                aesara.printing.debugprint(r, file=sio, print_type=True)
                raise AssertionError(
                    f"Something inferred a shape with {len(s)} dimensions "
                    f"for a variable with {int(r.type.ndim)} dimensions"
                    f" for the variable:\n{sio.getvalue()}"
                )

            shape_vars = []
            for i in range(r.type.ndim):
                if hasattr(r.type, "shape") and r.type.shape[i] is not None:
                    # Prefer the statically known dimension over the
                    # inferred symbolic one.
                    shape_vars.append(constant(r.type.shape[i], dtype="int64"))
                else:
                    shape_vars.append(self.unpack(s[i], r))
            # Broadcastable dimensions must have shape 1.
            assert all(
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                or self.lscalar_one.equals(shape_vars[i])
                or self.lscalar_one.equals(extract_constant(shape_vars[i]))
                for i in range(r.type.ndim)
            )
            self.shape_of[r] = tuple(shape_vars)
            for sv in shape_vars:
                self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def update_shape(self, r, other_r):
        """Replace shape of r by shape of other_r.

        If, on some dimensions, the shape of other_r is not informative,
        keep the shape of r on those dimensions.
        """
        # other_r should already have a shape
        assert other_r in self.shape_of, ("other_r not in shape_of", other_r)
        other_shape = self.shape_of[other_r]

        # If other_shape has no information, call is pointless.
        if other_shape is None:
            return

        if r in self.shape_of:
            r_shape = self.shape_of[r]
        else:
            # If no info is known on r's shape, use other_shape
            self.set_shape(r, other_shape)
            return
        if (
            other_r.owner
            and r.owner
            and other_r.owner.inputs == r.owner.inputs
            and other_r.owner.op == r.owner.op
        ):
            # We are doing a merge, so the two shape graphs will be the
            # same. This is only done so that we call `ancestors` less
            # frequently.
            return

        # Merge other_shape with r_shape, giving the priority to other_shape
        merged_shape = []
        for i, ps in enumerate(other_shape):
            if r_shape is None and other_shape:
                merged_shape.append(other_shape[i])
            elif (
                ps.owner
                and isinstance(getattr(ps.owner, "op", None), Shape_i)
                and ps.owner.op.i == i
                and ps.owner.inputs[0] in (r, other_r)
            ):
                # If other_shape[i] is uninformative, use r_shape[i].
                # For now, we consider 2 cases of uninformative other_shape[i]:
                #  - Shape_i(i)(other_r);
                #  - Shape_i(i)(r).
                merged_shape.append(r_shape[i])
            elif isinstance(r_shape[i], (Constant, int)):
                # We do this to call less often ancestors and make
                # sure we have the simplest shape possible.
                merged_shape.append(r_shape[i])
            elif isinstance(other_shape[i], (Constant, int)):
                # We do this to call less often ancestors and make
                # sure we have the simplest shape possible.
                merged_shape.append(other_shape[i])
            elif other_shape[i] == r_shape[i]:
                # This mean the shape is equivalent
                # We do not want to do the ancestor check in those cases
                merged_shape.append(r_shape[i])
            elif r_shape[i] in ancestors([other_shape[i]]):
                # Another case where we want to use r_shape[i] is when
                # other_shape[i] actually depends on r_shape[i]. In that case,
                # we do not want to substitute an expression with another that
                # is strictly more complex. Such a substitution could also lead
                # to cycles: if (in the future) r_shape[i] gets replaced by an
                # expression of other_shape[i], other_shape[i] may end up
                # depending on itself.
                merged_shape.append(r_shape[i])
            else:
                merged_shape.append(other_shape[i])
        assert all(
            (
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                and not other_r.type.broadcastable[i]
            )
            or self.lscalar_one.equals(merged_shape[i])
            or self.lscalar_one.equals(
                extract_constant(merged_shape[i], only_process_constants=True)
            )
            for i in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(merged_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def set_shape_i(self, r, i, s_i):
        """Replace element i of shape_of[r] by s_i"""
        assert r in self.shape_of
        prev_shape = self.shape_of[r]
        # prev_shape is a tuple, so we cannot change it inplace,
        # so we build another one.
        new_shape = []
        for j, s_j in enumerate(prev_shape):
            if j == i:
                new_shape.append(self.unpack(s_i, r))
            else:
                new_shape.append(s_j)
        assert all(
            not hasattr(r.type, "broadcastable")
            or not r.type.broadcastable[idx]
            or self.lscalar_one.equals(new_shape[idx])
            or self.lscalar_one.equals(extract_constant(new_shape[idx]))
            for idx in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(new_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def init_r(self, r):
        """Register r's shape in the shape_of dictionary."""
        if r not in self.shape_of:
            self.set_shape(r, self.shape_tuple(r))

    def make_vector_shape(self, r):
        """Return `r`'s tracked shape as a single symbolic 1-d int64 vector."""
        return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")

    def on_attach(self, fgraph):
        if hasattr(fgraph, "shape_feature"):
            raise AlreadyThere("This FunctionGraph already has a ShapeFeature")

        if hasattr(self, "fgraph") and self.fgraph != fgraph:
            raise Exception("This ShapeFeature is already attached to a graph")

        self.fgraph = fgraph

        fgraph.shape_feature = self
        # Must be local to the object as otherwise we reuse the same
        # variable for multiple fgraph!
        self.lscalar_one = constant(1, dtype="int64")
        assert self.lscalar_one.type.dtype == "int64"

        # NOTE(review): `self.fgraph` was already assigned above; this
        # second assignment is redundant.
        self.fgraph = fgraph
        # Variable -> tuple(scalars) or None  (All tensor vars map to tuple)
        self.shape_of = {}
        # Variable ->
        self.scheduled = {}
        # shape var -> graph v
        self.shape_of_reverse_index = {}

        for node in fgraph.toposort():
            self.on_import(fgraph, node, reason="on_attach")

    def on_detach(self, fgraph):
        # Drop all tracked state and unlink this feature from the graph.
        self.shape_of = {}
        self.scheduled = {}
        self.shape_of_reverse_index = {}
        self.fgraph = None
        del fgraph.shape_feature

    def on_import(self, fgraph, node, reason):
        if node.outputs[0] in self.shape_of:
            # this is a revert, not really an import
            for r in node.outputs + node.inputs:
                assert r in self.shape_of
            return

        for i, r in enumerate(node.inputs):
            # make sure we have shapes for the inputs
            self.init_r(r)

        o_shapes = self.get_node_infer_shape(node)

        # this is packed information
        # an element of o_shapes is either None or a tuple
        #   elements of the tuple can be either strings, or ints
        if len(o_shapes) != len(node.outputs):
            raise Exception(
                (
                    f'The infer_shape method for the Op "{node.op}" returned a list '
                    f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} "
                    f" != len(node.outputs) = {len(node.outputs)}"
                )
            )

        # Ensure shapes are in 'int64'. This is to make sure the assert
        # found in the `local_useless_subtensor` rewrite does not fail.
        for sh_idx, sh in enumerate(o_shapes):
            if sh is None:
                continue
            if not isinstance(sh, (list, tuple)):
                raise ValueError(
                    f"infer_shape of {node} didn't return a list of"
                    f" list. It returned '{o_shapes}'"
                )
            new_shape = []
            for i, d in enumerate(sh):
                # Note: we ignore any shape element that is not typed (i.e.,
                # does not have a 'dtype' attribute). This means there may
                # still remain int elements that are int32 on 32-bit platforms,
                # but this works with `local_useless_subtensor`, so for now we
                # keep it this way. See #266 for a better long-term fix.
                if getattr(d, "dtype", "int64") != "int64":
                    assert d.dtype in discrete_dtypes, (node, d.dtype)
                    assert str(d.dtype) != "uint64", node
                    new_shape += sh[len(new_shape) : i + 1]
                    if isinstance(d, Constant):
                        casted_d = constant(d.data, dtype="int64")
                    else:
                        casted_d = cast(d, "int64")
                    new_shape[i] = casted_d
            if new_shape:
                # We replace the shape with wrong dtype by the one with
                # 'int64'.
                new_shape += sh[len(new_shape) :]
                o_shapes[sh_idx] = tuple(new_shape)

        for r, s in zip(node.outputs, o_shapes):
            self.set_shape(r, s)

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        if new_r not in self.shape_of:
            # It happen that the fgraph didn't called on_import for some
            # new_r. This happen when new_r don't have an
            # owner(i.e. it is a constant or an input of the graph)
            # update_shape suppose that r and new_r are in shape_of.
            self.init_r(new_r)

        # This tells us that r and new_r must have the same shape if
        # we didn't know that the shapes are related, now we do.
        self.update_shape(new_r, r)

        # change_input happens in two cases:
        # 1) we are trying to get rid of r, or
        # 2) we are putting things back after a failed transaction.

        # In case 1, if r has a shape_i client, we will want to
        # replace the shape_i of r with the shape of new_r. Say that
        # r is *scheduled*.
        # At that point, node is no longer a client of r, but of new_r
        for (shpnode, idx) in fgraph.clients[r] + [(node, i)]:
            if isinstance(getattr(shpnode, "op", None), Shape_i):
                idx = shpnode.op.i
                repl = self.shape_of[new_r][idx]
                if repl.owner is shpnode:
                    # This mean the replacement shape object is
                    # exactly the same as the current shape object. So
                    # no need for replacement.
                    continue
                if (
                    repl.owner
                    and repl.owner.inputs[0] is shpnode.inputs[0]
                    and isinstance(repl.owner.op, Shape_i)
                    and repl.owner.op.i == shpnode.op.i
                ):
                    # The replacement is a shape_i of the same
                    # input. So no need to do this equivalent
                    # replacement.
                    continue

                if shpnode.outputs[0] in ancestors([repl]):
                    raise InconsistencyError(
                        "This substitution would insert a cycle in the graph:"
                        f"node: {node}, i: {i}, r: {r}, new_r: {new_r}"
                    )

                self.scheduled[shpnode] = new_r
        # In case 2, if r is a variable that we've scheduled for shape update,
        # then we should cancel it.
        unscheduled = [k for k, v in self.scheduled.items() if v == r]
        for k in unscheduled:
            del self.scheduled[k]

        # In either case, r could be in shape_of.values(), that is, r itself
        # is the shape of something. In that case, we want to update
        # the value in shape_of, to keep it up-to-date.
        for v in self.shape_of_reverse_index.get(r, []):
            # The reverse index is only approximate. It is not updated on
            # deletion of variables, or on change_input so it might be the
            # case that there are a few extra `v`'s in it that no longer have
            # a shape of r or possibly have been deleted from shape_of
            # entirely. The important thing is that it permits to recall
            # all variables with r in their shape.
            for ii, svi in enumerate(self.shape_of.get(v, [])):
                if svi == r:
                    self.set_shape_i(v, ii, new_r)
        self.shape_of_reverse_index[r] = set()

    def same_shape(
        self,
        x: Variable,
        y: Variable,
        dim_x: Optional[int] = None,
        dim_y: Optional[int] = None,
    ) -> bool:
        """Return ``True`` if `x` and `y` have the same shape.

        Parameters
        ==========
        x
            The `Variable` for which its shape is to be compared with `y`'s shape.
        y
            The `Variable` for which its shape is to be compared with `x`'s shape.
        dim_x
            If non ``None``, compare only the dimension of `x` equal to
            `dim_x`.
        dim_y
            If non ``None``, compare only the dimension of `y` equal to
            `dim_y`.

        """
        sx = self.shape_of[x]
        sy = self.shape_of[y]

        if sx is None or sy is None:
            return False

        if dim_x is not None:
            sx = [sx[dim_x]]

        if dim_y is not None:
            sy = [sy[dim_y]]

        if len(sx) != len(sy):
            return False

        # Canonicalize the graphs so that comparisons are reasonable
        # TODO FIXME: This should *not* need to be performed manually here.
        # Instead, the shape information in `self.shape_of` should be operated
        # upon alongside all the other elements in a `FunctionGraph` (e.g. as
        # if `self.shape_of.values()` were additional outputs).
        shapes_fg = FunctionGraph(
            outputs=sx + sy,
            # features=[self],
            clone=True,
            # copy_inputs=False,
        )
        from aesara.graph.rewriting.utils import rewrite_graph

        canon_shapes_fg = type_cast(
            FunctionGraph,
            rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding),
        )
        canon_shapes = canon_shapes_fg.outputs

        sx = canon_shapes[: len(sx)]
        sy = canon_shapes[len(sx) :]

        for dx, dy in zip(sx, sy):
            if not equal_computations([dx], [dy]):
                return False

        return True

    def clone(self):
        # Return a fresh, unattached instance (per-graph state lives on the
        # attached copy, not the prototype).
        return type(self)()
class ShapeOptimizer(GraphRewriter):
    """Rewriter whose only job is to attach a `ShapeFeature` to the graph."""

    def add_requirements(self, fgraph):
        # The feature does all the actual shape tracking.
        fgraph.attach_feature(ShapeFeature())

    def apply(self, fgraph):
        # Nothing to do here: the work happens through the attached feature.
        pass
class UnShapeOptimizer(GraphRewriter):
    """Rewriter that removes `ShapeFeature` instances from the graph."""

    def apply(self, fgraph):
        # Bug fix: iterate over a snapshot of the feature list.
        # `remove_feature` mutates `fgraph._features`, and removing from the
        # live list while iterating it can skip subsequent entries.
        for feature in list(fgraph._features):
            if isinstance(feature, ShapeFeature):
                fgraph.remove_feature(feature)
# Register after the `merge1` rewrite (position 0): we don't want to track
# the shapes of nodes that will be merged away.
aesara.compile.mode.optdb.register(
    "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
)
# Not enabled by default for now. Some cross-entropy rewrites use the shape
# feature; they run at step 2.01, uncanonicalize at step 3, and the GPU
# transfer at 48.5, so position 10 seems reasonable.
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op):
    """Build a node rewriter collapsing two directly nested `op` applications."""

    @node_rewriter([op])
    def f(fgraph, node):
        """
        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
        """
        if not check_chain(node, op, op):
            return False

        # TODO: this can permit a failing program to run by eliminating
        # the lower reshape
        inner_input = node.inputs[0].owner.inputs[0]
        rval = node.op(inner_input, node.inputs[1])

        # Copy over stacktrace from previous output node, as any error
        # in new computational graph would have been caused by last op
        # in the old computational graph.
        copy_stack_trace(node.outputs, rval)

        # It might happen that the desired output of this node has a
        # broadcastable pattern that does not match that of 'rval'. This is
        # when originally, we were able to figure out that one of the
        # dimensions of the reshape is one, but some other transformation
        # replaced the shape by one for which this cannot be guessed.
        # We should try to figure out why we lost the information about this
        # constant value... but in the meantime, better not apply this
        # rewrite.
        if rval.broadcastable != node.outputs[0].broadcastable:
            return False
        return [rval]

    return f
# Collapse chained reshapes: Reshape(Reshape(x, s1), s2) -> Reshape(x, s2).
register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
@register_useless
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
    """Remove two kinds of useless `Reshape`.

    1. Remove `Reshape` when both the input and output have a single
       dimension (and the same broadcast pattern).
    2. Remove `Reshape` when reshaping to the (symbolic) shape of the
       input itself.

    Returns ``[inp]`` when the reshape is useless and ``False`` otherwise.
    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    # A reshape can only be a no-op when it preserves the number of
    # dimensions.
    if inp.ndim != output.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # This could hide errors if the user provides inconsistent shapes.
    if inp.ndim == 1 and output.ndim == 1 and inp.broadcastable == output.broadcastable:
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

        nb_m1 = 0  # number of -1 ("infer this dimension") entries seen
        shape_match = [False] * inp.ndim
        for dim in range(inp.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.broadcastable[dim] is True
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.broadcastable[dim] and cst_outshp_i == 1:
                shape_match[dim] = True
                continue

            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        # At most one -1 can be resolved unambiguously by a reshape, so
        # only treat the reshape as useless when nb_m1 <= 1.
        if all(shape_match) and nb_m1 <= 1:
            return [inp]

    # TODO later: if all the shapes except one match, we may want to
    # consider it useless as well, like we do in the 1-dim case.
    return False
@register_canonicalize
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
    """
    Broadcastable dimensions in Reshape are replaced with dimshuffle.

    The goal is to avoid using reshape to add or remove broadcastable
    dimensions, but use dimshuffle instead, so dimshuffles can cancel out
    or be removed later on.

    For example:
        - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
        - reshape(x, (1, m, 1, n, 1, 1))
          --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    dimshuffle_new_order = []
    new_output_shape = []
    index = 0  # index over the output of the new reshape
    for i in range(output.ndim):
        # Since output_shape is a symbolic vector, we trust extract_constant
        # to go through however it is formed to see if its i-th element is 1.
        # We need only_process_constants=False for that.
        dim = extract_constant(
            output_shape[i], only_process_constants=False, elemwise=False
        )
        if dim == 1:
            # Unit dimension: to be re-added by the dimshuffle as a
            # broadcastable ('x') axis instead of by the reshape.
            dimshuffle_new_order.append("x")
        else:
            dimshuffle_new_order.append(index)
            new_output_shape.append(dim)
            index = index + 1
    # Only rewrite when at least one unit dimension was stripped; otherwise
    # there is nothing for the dimshuffle to do.
    if index != output.ndim:
        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
        copy_stack_trace(output, inner)
        new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
        copy_stack_trace(output, new_node)
        return new_node
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
    Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))

    Notes
    -----
    This rewrite is needed by `log1msigm_to_softplus` in order to get applied
    when there is a reshape.

    """
    if not isinstance(node.op, Reshape):
        return

    elemwise_node = node.inputs[0].owner
    if elemwise_node is None:
        return
    if not (isinstance(elemwise_node.op, Elemwise) and len(elemwise_node.inputs) == 1):
        return

    # Apply the reshape directly to the elemwise's input ...
    reshaped = node.op(elemwise_node.inputs[0], node.inputs[1])
    # Any error in the new Reshape could only have been caused by the old one.
    copy_stack_trace(node.outputs, reshaped)

    # ... then re-apply the unary elemwise on top of it.
    lifted = elemwise_node.op(reshaped)
    # An error in the new graph could have been caused by either op, so copy
    # the stack traces of both the old Reshape and the old Elemwise.
    copy_stack_trace(node.outputs + node.inputs, lifted)
    return [lifted]
# `tensor_copy` is an identity operation, so removing it is always safe.
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_useless
@register_canonicalize
@node_rewriter([SpecifyShape])
def local_merge_consecutive_specify_shape(fgraph, node):
    """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,

    where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
    """
    if not isinstance(node.op, SpecifyShape):
        return False

    inner = node.inputs[0]
    inner_node = inner.owner
    if inner_node is None or not isinstance(inner_node.op, SpecifyShape):
        return False

    base, *merged_shape = inner_node.inputs

    # The outer `SpecifyShape`'s non-None entries override the inner ones.
    for idx, outer_sh in enumerate(node.inputs[1:]):
        if not NoneConst.equals(outer_sh):
            merged_shape[idx] = outer_sh

    # TODO: We could make sure that the overlapping shapes of the two
    # `SpecifyShape`s are the same.

    return [specify_shape(base, merged_shape)]
@register_useless
@register_canonicalize
@node_rewriter([Shape])
def local_Shape_of_SpecifyShape(fgraph, node):
    """Replace ``specify_shape(x, s).shape`` with ``s``."""

    if not isinstance(node.op, Shape):
        return False

    spec_var = node.inputs[0]
    spec_node = spec_var.owner
    if not isinstance(getattr(spec_node, "op", None), SpecifyShape):
        return False

    base, *dims = spec_node.inputs

    # Unspecified (`NoneConst`) entries fall back to the symbolic shape of
    # the underlying variable.
    dims = [
        shape_i(base, idx, fgraph) if NoneConst.equals(d) else d
        for idx, d in enumerate(dims)
    ]

    return [stack(dims).astype(np.int64)]
@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node):
    """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""

    if not isinstance(node.op, Shape_i):
        return False

    (inp,) = node.inputs

    if not isinstance(inp.type, TensorType):
        return False

    # A broadcastable dimension always has length one.
    if inp.broadcastable[node.op.i]:
        return [as_tensor_variable(1, dtype=np.int64)]
@register_specialize
@register_canonicalize
@node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node):
    """Replace a `Shape` node with the `ShapeFeature`'s per-dimension shape vector."""
    if not isinstance(node.op, Shape):
        return

    # This rewrite needs the shape-tracking feature to be installed.
    shape_feature = getattr(fgraph, "shape_feature", None)
    if shape_feature is None:
        return

    ret = shape_feature.make_vector_shape(node.inputs[0])
    # Carry the old output's stack trace over to the replacement.
    copy_stack_trace(node.outputs[0], ret)
    return [ret]
@register_specialize
@register_canonicalize
@node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node):
    """Replace a scheduled `Shape_i` node with the shape the `ShapeFeature` tracked for it."""
    if not isinstance(node.op, Shape_i):
        return False

    shape_feature = getattr(fgraph, "shape_feature", None)
    if shape_feature is None:
        return False

    if node not in shape_feature.scheduled:
        return False

    # Don't unschedule the node: it could be reinserted in the fgraph, and
    # we must not mutate the `ShapeFeature`'s internal structure here.
    replacement = shape_feature.scheduled[node]
    return [shape_feature.shape_of[replacement][node.op.i]]
@register_canonicalize
@node_rewriter([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node):
    """
    Removes useless DimShuffle operation inside Reshape:

      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
      reshape(col.dimshuffle(0), shp) => reshape(col, shp)

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False
    if not (
        node.inputs[0].owner is not None
        and isinstance(node.inputs[0].owner.op, DimShuffle)
    ):
        return False

    new_order = node.inputs[0].owner.op.new_order
    inp = node.inputs[0].owner.inputs[0]
    broadcastables = node.inputs[0].broadcastable
    # Collect the DimShuffle order restricted to the non-broadcastable
    # (data-carrying) dimensions.  The DimShuffle may be dropped only if it
    # did not permute those: adding or removing length-1 axes is irrelevant
    # to the subsequent reshape, but reordering real axes is not.
    new_order_of_nonbroadcast = []
    for i, bd in zip(new_order, broadcastables):
        if not bd:
            new_order_of_nonbroadcast.append(i)
    no_change_in_order = all(
        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
        for i in range(len(new_order_of_nonbroadcast) - 1)
    )
    if no_change_in_order:
        shape = node.inputs[1]
        ret = op.__class__(node.outputs[0].ndim)(inp, shape)
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.

    TODO: Implement equivalent rewrite for SpecifyShape
    """
    if not isinstance(node.op, Unbroadcast):
        return None

    x = node.inputs[0]
    if x.broadcastable == node.outputs[0].broadcastable:
        # The broadcast pattern is untouched: drop the op entirely.
        # No need to copy over a stack trace, because `x` should already
        # have one.
        return [x]

    # Keep only the axes that modify something.
    useful_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
    if useful_axes == node.op.axes:
        # Every axis is useful; nothing to simplify.
        return None

    r = unbroadcast(x, *useful_axes)
    # Copy over stacktrace from previous output
    copy_stack_trace(node.outputs, r)
    return [r]
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
    """
    Lifts `Unbroadcast` through unary Elemwise operations,
    and merges consecutive `Unbroadcast`s.

    Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
    Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)

    TODO: Implement equivalent Elemwise lift for SpecifyShape
    """
    op = node.op
    if not isinstance(op, Unbroadcast):
        return False

    inp = node.inputs[0]
    inode = inp.owner
    if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
        # Only lift when this node is the sole client of the Elemwise
        # output; otherwise the Elemwise would end up duplicated.
        if len(fgraph.clients.get(inp, ())) == 1:
            unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
            copy_stack_trace(node.outputs, unbroadcasted)

            rval = inode.op.make_node(unbroadcasted).outputs

            # Copy over stacktrace from previous output (after unbroadcasting)
            # and input (after elemwise operation) to new output, because an
            # error in the new graph could have been caused by either of the
            # two ops.
            copy_stack_trace(node.outputs + node.inputs, rval)
            return rval

    if inode and isinstance(inode.op, Unbroadcast):
        # Merge axis of each unbroadcast
        axis = tuple(set(inode.op.axes).union(set(op.axes)))
        iinput = inode.inputs[0]
        rval = [unbroadcast(iinput, *axis)]
        # Copy over stacktrace from previous output (after second unbroadcasting)
        # and from previous input (after first unbroadcasting) because an error in
        # the new graph could have been caused by either of the two Unbroadcast ops.
        copy_stack_trace(node.outputs + node.inputs, rval)
        return rval
......@@ -63,7 +63,9 @@ def shape_of_variables(fgraph, input_shapes):
"""
if not hasattr(fgraph, "shape_feature"):
fgraph.attach_feature(aesara.tensor.rewriting.basic.ShapeFeature())
from aesara.tensor.rewriting.shape import ShapeFeature
fgraph.attach_feature(ShapeFeature())
input_dims = [
dimension
......
......@@ -21,7 +21,7 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sigmoid
from aesara.tensor.math import sum as at_sum
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.rewriting.basic import ShapeOptimizer
from aesara.tensor.rewriting.shape import ShapeOptimizer
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools
......
......@@ -55,7 +55,7 @@ from aesara.tensor.random.basic import (
wald,
weibull,
)
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param
......
import contextlib
import copy
import numpy as np
......@@ -10,20 +9,15 @@ import aesara.tensor as at
from aesara import shared
from aesara.compile import optdb
from aesara.compile.function import function
from aesara.compile.mode import OPT_NONE, Mode, get_default_mode, get_mode
from aesara.compile.mode import get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
from aesara.graph.rewriting.basic import check_stack_trace, out2in
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.raise_op import Assert, CheckAndRaise
from aesara.scalar.basic import Composite
from aesara.tensor.basic import (
Alloc,
Join,
......@@ -31,21 +25,15 @@ from aesara.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
alloc,
as_tensor_variable,
join,
second,
tile,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from aesara.tensor.math import (
add,
bitwise_and,
bitwise_or,
bitwise_xor,
cos,
cosh,
dot,
eq,
exp,
......@@ -53,46 +41,32 @@ from aesara.tensor.math import (
ge,
gt,
int_div,
invert,
iround,
le,
log,
log2,
log10,
lt,
maximum,
minimum,
mul,
neg,
neq,
)
from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import reciprocal
from aesara.tensor.math import round as at_round
from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import softplus, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math import true_div
from aesara.tensor.rewriting.basic import (
ShapeFeature,
assert_op,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
local_useless_alloc,
local_useless_dimshuffle_in_reshape,
local_useless_elemwise,
local_useless_reshape,
register_specialize,
)
from aesara.tensor.rewriting.math import local_lift_transpose_through_dot
from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
Unbroadcast,
reshape,
shape,
specify_shape,
unbroadcast,
)
......@@ -102,17 +76,14 @@ from aesara.tensor.subtensor import (
advanced_inc_subtensor,
advanced_inc_subtensor1,
inc_subtensor,
set_subtensor,
)
from aesara.tensor.type import (
TensorType,
dmatrices,
dmatrix,
dscalar,
dvector,
fmatrix,
fscalar,
fvector,
imatrices,
iscalar,
iscalars,
......@@ -129,7 +100,6 @@ from aesara.tensor.type import (
tensor4,
values_eq_approx_remove_nan,
vector,
vectors,
)
from tests import unittest_tools as utt
......@@ -139,8 +109,6 @@ if rewrite_mode == "FAST_COMPILE":
rewrite_mode = "FAST_RUN"
rewrite_mode = get_mode(rewrite_mode)
dimshuffle_lift = out2in(local_dimshuffle_lift)
_stabilize_rewrites = RewriteDatabaseQuery(include=["fast_run"])
_stabilize_rewrites.position_cutoff = 1.51
_stabilize_rewrites = optdb.query(_stabilize_rewrites)
......@@ -153,10 +121,6 @@ _fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"])
_fast_run_rewrites = optdb.query(_fast_run_rewrites)
def ds(x, y):
    """Shorthand: apply a `DimShuffle` with order `y` to the variable `x`."""
    shuffle_op = DimShuffle(x.type.broadcastable, y)
    return shuffle_op(x)
def rewrite(g, level="fast_run"):
if level == "fast_run":
_fast_run_rewrites.rewrite(g)
......@@ -169,1124 +133,6 @@ def rewrite(g, level="fast_run"):
return g
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
    """Create three float64 tensor variables named x, y, z with the given shapes."""

    def _make(name, shp):
        # One-liner helper: build a float64 TensorType and name the variable.
        return TensorType(shape=shp, dtype="float64")(name)

    return _make("x", xbc), _make("y", ybc), _make("z", zbc)
class TestDimshuffleLift:
    """Tests for the `local_dimshuffle_lift` rewrite (via the `dimshuffle_lift` rewriter).

    The assertions compare the exact string rendering of the graphs before
    and after rewriting.
    """

    def test_double_transpose(self):
        # Two successive transposes cancel out entirely.
        x, y, z = inputs()
        e = ds(ds(x, (1, 0)), (1, 0))
        g = FunctionGraph([x], [e])
        assert (
            str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # no need to check_stack_trace as graph is supposed to be empty

    def test_merge2(self):
        # Two chained dimshuffles merge into a single one.
        x, y, z = inputs()
        e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
        g = FunctionGraph([x], [e])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")

    def test_elim3(self):
        # Three chained dimshuffles that compose to the identity are removed.
        x, y, z = inputs()
        e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
        g = FunctionGraph([x], [e])
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
            "(InplaceDimShuffle{0,x,1}(x))))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)", str(g)
        # no need to check_stack_trace as graph is supposed to be empty

    def test_lift(self):
        # DimShuffles are lifted above the Elemwise additions.
        x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
        e = x + y + z
        g = FunctionGraph([x, y, z], [e])

        # It does not really matter if the DimShuffles are inplace
        # or not.
        init_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
        )
        init_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
        )
        assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)

        rewrite_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
        )
        rewrite_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")

    def test_recursive_lift(self):
        # The transform must recurse through nested Elemwise/DimShuffle chains.
        v = vector(dtype="float64")
        m = matrix(dtype="float64")
        out = ((v + 42) * (m + 84)).T
        g = FunctionGraph([v, m], [out])
        init_str_g = (
            "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
            "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None,))>, "
            "InplaceDimShuffle{x}(TensorConstant{42}))), "
            "Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None, None))>, "
            "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
        )
        assert str(g) == init_str_g
        new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
        new_g = FunctionGraph(g.inputs, [new_out])
        rewrite_str_g = (
            "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{0,x}(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{42})), "
            "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
            "(<TensorType(float64, (None, None))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
        )
        assert str(new_g) == rewrite_str_g
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(new_g, ops_to_check="all")

    def test_useless_dimshuffle(self):
        # An identity dimshuffle is removed outright.
        x, _, _ = inputs()
        e = ds(x, (0, 1))
        g = FunctionGraph([x], [e])
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")

    def test_dimshuffle_on_broadcastable(self):
        # Dimshuffles that only touch broadcastable axes are useless; the
        # ones that reorder real axes (or expand a scalar) must be kept.
        x, y, z = inputs([False, True], [True, False, True], [False, False, True])
        u = at.constant(1)
        ds_x = ds(x, (0, "x"))  # useless
        ds_y = ds(y, (2, 1, 0))  # useless
        ds_z = ds(z, (2, 1, 0))  # useful
        ds_u = ds(u, ("x"))  # useful
        g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        dimshuffle_lift.rewrite(g)
        assert (
            str(g)
            == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")
def test_local_useless_dimshuffle_in_reshape():
    """DimShuffles that only add/remove broadcastable axes before a Reshape are dropped."""
    vec = TensorType(shape=(False,), dtype="float64")("vector")
    mat = TensorType(shape=(False, False), dtype="float64")("mat")
    row = TensorType(shape=(True, False), dtype="float64")("row")
    col = TensorType(shape=(False, True), dtype="float64")("col")

    reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape)
    reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape)
    reshape_dimshuffle_row = reshape(row.dimshuffle(1, "x"), row.shape)
    reshape_dimshuffle_col = reshape(col.dimshuffle(0), col.shape)

    g = FunctionGraph(
        [vec, mat, row, col],
        [
            reshape_dimshuffle_vector,
            reshape_dimshuffle_mat,
            reshape_dimshuffle_row,
            reshape_dimshuffle_col,
        ],
    )

    assert str(g) == (
        "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
        "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
        "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
        "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
    )
    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
    useless_dimshuffle_in_reshape.rewrite(g)
    assert str(g) == (
        "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
        "Reshape{2}(mat, Shape(mat)), "
        "Reshape{2}(row, Shape(row)), "
        "Reshape{2}(col, Shape(col)))"
    )

    # Check stacktrace was copied over correctly after rewrite was applied
    assert check_stack_trace(g, ops_to_check="all")

    # Check that the rewrite does not get applied when the order
    # of dimensions has changed.
    reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
    h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
    str_h = str(h)
    useless_dimshuffle_in_reshape.rewrite(h)
    assert str(h) == str_h
class TestFusion:
rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"inplace",
],
exclude=["cxx_only", "BlasOpt"],
)
mode = Mode(get_default_mode().linker, rewrites)
_shared = staticmethod(shared)
topo_exclude = ()
def my_init(dtype="float64", num=0):
return np.zeros((5, 5), dtype=dtype) + num
fw, fx, fy, fz = [
tensor(dtype="float32", shape=[False] * 2, name=n) for n in "wxyz"
]
dw, dx, dy, dz = [
tensor(dtype="float64", shape=[False] * 2, name=n) for n in "wxyz"
]
ix, iy, iz = [tensor(dtype="int32", shape=[False] * 2, name=n) for n in "xyz"]
fv = fvector("v")
fs = fscalar("s")
fwv = my_init("float32", 1)
fxv = my_init("float32", 2)
fyv = my_init("float32", 3)
fzv = my_init("float32", 4)
fvv = _asarray(np.random.random(5), dtype="float32")
fsv = np.asarray(np.random.random(), dtype="float32")
dwv = my_init("float64", 5)
ixv = _asarray(my_init(num=60), dtype="int32")
iyv = _asarray(my_init(num=70), dtype="int32")
izv = _asarray(my_init(num=70), dtype="int32")
fwx = fw + fx
ftanx = tan(fx)
@pytest.mark.parametrize(
"case",
[
(
fx + fy + fz,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + fzv,
"float32",
), # 0
(
fx * fy * fz,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv * fyv * fzv,
"float32",
), # 1
(
fx + fy * fz,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv * fzv,
"float32",
), # 2
(
fx * fy + fz,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv * fyv + fzv,
"float32",
), # 3
(
fw + fx + fy + fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
),
(
(fw + fx) + (fy + fz),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
), # 5
(
((fw + fx) + fy) + fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
),
(
(fw + (fx + fy)) + fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
),
(
(fw + (fx + fy) + fz),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
),
(
fw + (fx + (fy + fz)),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
),
(
(fw + fx) + (fy + fz),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv,
"float32",
), # 10
(
fw * fx * fy * fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv * fxv * fyv * fzv,
"float32",
),
(
fw + fx * fy * fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv * fyv * fzv,
"float32",
),
(
fx + fy * fz * fx,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv * fzv * fxv,
"float32",
),
(
fx * fy + fz + fy,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv * fyv + fzv + fyv,
"float32",
),
(
fx * fy * fz * fw + fx + fy + fz + fw,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv,
"float32",
), # 15
# test with constant
(
(fw + fx) + (fy + fz) + 2.0,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv + 2,
"float32",
),
(
((fw + fx) + 2.0 + fy) + fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv + 2,
"float32",
),
(
(fw + (fx + 2.0 + fy)) + fz,
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv + 2,
"float32",
),
(
(fw + (fx + fy) + 2 + fz),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv + 2,
"float32",
),
(
fw + (fx + (fy + fz) + 2.0),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv + 2,
"float32",
), # 20
(
2 + (fw + fx) + (fy + fz),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
1,
fwv + fxv + fyv + fzv + 2,
"float32",
),
# mix float32 and float64
(
2 + (dw + fx) + (fy + fz),
(dw, fx, fy, fz),
(dwv, fxv, fyv, fzv),
1,
dwv + fxv + fyv + fzv + 2,
"float64",
),
(
2 + (fw + dw) + (fy + fz),
(fw, dw, fy, fz),
(fwv, dwv, fyv, fzv),
1,
fwv + dwv + fyv + fzv + 2,
"float64",
),
(
2 + (fw + fx) + (dw + fz),
(fw, fx, dw, fz),
(fwv, fxv, dwv, fzv),
1,
fwv + fxv + dwv + fzv + 2,
"float64",
),
(
2 + (fw + fx) + (fy + dw),
(fw, fx, fy, dw),
(fwv, fxv, fyv, dwv),
1,
fwv + fxv + fyv + dwv + 2,
"float64",
), # 25
# test when their is other op then elemwise.
(
(fwx.sum()) + (fwx) + (fy + fz),
(fw, fx, fy, fz),
(fwv, fxv, fyv, fzv),
4,
(fwv + fxv).sum() + fwv + fxv + fyv + fzv,
"float32",
),
# test other elemwise op
(
fx + fy + cos(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.cos(fzv),
"float32",
),
(
fx + fy + cosh(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.cosh(fzv),
"float32",
),
(
fx + fy + abs(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.absolute(fzv),
"float32",
),
(
ix + iy + abs(iz),
(ix, iy, iz),
(ixv, iyv, izv),
1,
ixv + iyv + np.absolute(izv),
"int32",
), # 30
(
fx + fy + log(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.log(fzv),
"float32",
),
(
fx + fy + log2(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.log2(fzv),
"float32",
),
(
fx + fy + log10(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.log10(fzv),
"float32",
),
(
fx + fy**fz,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv**fzv,
"float32",
), # pow
(
fx + fy + exp(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv + fyv + np.exp(fzv),
"float32",
), # 35
(
fx - fy - fz,
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv - fzv,
"float32",
),
(
fx - (fy / fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv / fzv),
"float32",
),
(
fx - true_div(fy, 2),
(fx, fy),
(fxv, fyv),
1,
fxv - (fyv / 2),
"float32",
),
(
fx - true_div(fy, fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv / fzv),
"float32",
),
(
fx - int_div(ix * 100, iy * 1000),
(fx, ix, iy),
(fxv, ixv, iyv),
1,
fxv - ((ixv * 100) // (iyv * 1000)),
{
"custom": "float64",
"numpy + floatX": config.floatX,
"numpy": "float64",
},
), # 40
(fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"),
(
fx - (fy % fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv % fzv),
"float32",
),
(
fx - (fy > fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv > fzv),
"float32",
),
(
fx - (fy >= fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv >= fzv),
"float32",
),
(
fx - (fy < fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv < fzv),
"float32",
), # 45
(
fx - (fy <= fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv <= fzv),
"float32",
),
(
fx - eq(fy, fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv == fzv),
"float32",
),
(
fx - neq(fy, fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - (fyv != fzv),
"float32",
),
(
fx - fy + tan(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + np.tan(fzv),
"float32",
),
(
fx - fy + tanh(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + np.tanh(fzv),
"float32",
), # 50
(
fx - fy + sin(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + np.sin(fzv),
"float32",
),
(
fx - fy + sinh(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + np.sinh(fzv),
"float32",
),
(
fx - fy + sqr(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + (fzv * fzv),
"float32",
),
(
fx - fy + sqrt(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + np.sqrt(fzv),
"float32",
),
(
fx - fy + reciprocal(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + (1 / fzv),
"float32",
), # 55
(
fx - fy + neg(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + (-fzv),
"float32",
),
(
fx - fy + at_round(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
fxv - fyv + np.round(fzv),
"float32",
),
(
ix - iy + iround(fz),
(ix, iy, fz),
(ixv, iyv, fzv),
1,
ixv - iyv + np.round(fzv),
"int64",
),
# Bit op
(
fx - bitwise_or(iy, iz),
(fx, iy, iz),
(fxv, iyv, izv),
1,
fxv - (iyv | izv),
{
"custom": "float64",
"numpy + floatX": config.floatX,
"numpy": "float64",
},
),
(
fx - xor(iy, iz),
(fx, iy, iz),
(fxv, iyv, izv),
1,
fxv - (iyv ^ izv),
{
"custom": "float64",
"numpy + floatX": config.floatX,
"numpy": "float64",
},
), # 60
(
fx - bitwise_and(iy, iz),
(fx, iy, iz),
(fxv, iyv, izv),
1,
fxv - (iyv & izv),
{
"custom": "float64",
"numpy + floatX": config.floatX,
"numpy": "float64",
},
),
(
fx - invert(iy),
(fx, iy),
(fxv, iyv),
1,
fxv - (~iyv),
{
"custom": "float64",
"numpy + floatX": config.floatX,
"numpy": "float64",
},
),
(
fx - at.cast(fy, dtype="float64"),
(fx, fy),
(fxv, fyv),
1,
fxv - np.asarray(fyv, "float64"),
"float64",
),
(
at_pow(fx * fy + fz, fx * fy),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
np.power(fxv * fyv + fzv, fxv * fyv),
"float32",
),
(
fv + fy**fz,
(fv, fy, fz),
(fvv, fyv, fzv),
2,
fvv + fyv**fzv,
"float32",
), # fused with a dimshuffle #65
(
fv - fy + tanh(fz),
(fv, fy, fz),
(fvv, fyv, fzv),
2,
fvv - fyv + np.tanh(fzv),
"float32",
), # fused with a dimshuffle
# Cases where the same input is reused many times.
(
mul(fx, fx, fx, fx),
(fx,),
(fxv,),
1,
fxv * fxv * fxv * fxv,
"float32",
),
(
mul(fx, ftanx, ftanx),
(fx,),
(fxv,),
1,
fxv * np.tan(fxv) * np.tan(fxv),
"float32",
),
(
mul(fx, ftanx, ftanx, fx),
(fx,),
(fxv,),
1,
fxv * np.tan(fxv) * np.tan(fxv) * fxv,
"float32",
),
(
mul(ftanx, ftanx, fx + fy),
(fx, fy),
(fxv, fyv),
1,
np.tan(fxv) * np.tan(fxv) * (fxv + fyv),
"float32",
), # 70
# Cases with different broadcast pattern. They should not
# be merged as this would duplicate computation
# The graph should have 2 elemwise and 1 dimshuffle
(
fx * sin(fs),
(fx, fs),
(fxv, fsv),
3,
fxv * np.sin(fsv),
"float32",
),
],
)
    def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
        """Verify that `Elemwise` fusion works."""
        # Each case bundles: the output graph, its symbolic inputs, the
        # concrete input values, the expected number of `Elemwise` nodes
        # after fusion, the expected numeric result, and the output dtype.
        g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
        if isinstance(out_dtype, dict):
            # The expected dtype depends on the active casting policy.
            out_dtype = out_dtype[config.cast_policy]

        if self._shared is None:
            # Evaluate the graph directly as the function's output.
            f = function(list(sym_inputs), g, mode=self.mode)
            for x in range(nb_repeat):
                out = f(*val_inputs)
        else:
            # Evaluate through an in-place update of a shared variable.
            out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
            assert out.dtype == g.dtype
            f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
            for x in range(nb_repeat):
                f(*val_inputs)
            out = out.get_value()

        # float32 accumulates more rounding error, so loosen the tolerance.
        atol = 1e-8
        if out_dtype == "float32":
            atol = 1e-6
        assert np.allclose(out, answer * nb_repeat, atol=atol)

        topo = f.maker.fgraph.toposort()
        topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
        if assert_len_topo:
            assert len(topo_) == nb_elemwise
            if nb_elemwise == 1:
                # if no variable appears multiple times in the
                # input of g,
                # check that the number of input to the Composite
                # Elemwise is ok
                if len(set(g.owner.inputs)) == len(g.owner.inputs):
                    expected_len_sym_inputs = sum(
                        not isinstance(x, Constant) for x in topo_[0].inputs
                    )
                    assert expected_len_sym_inputs == len(sym_inputs)

        assert out_dtype == out.dtype
def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
inpts = vectors(["i%i" % i for i in range(35)])
# Make an elemwise graph looking like:
# sin(i34 + sin(i33 + sin(... i1 + sin(i0) ...)))
out = sin(inpts[0])
for idx in range(1, 35):
out = sin(inpts[idx] + out)
with config.change_flags(cxx=""):
f = function(inpts, out, mode=self.mode)
# Make sure they all weren't fused
composite_nodes = [
node
for node in f.maker.fgraph.toposort()
if isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite)
]
assert not any(len(node.inputs) > 31 for node in composite_nodes)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_big_fusion(self):
# In the past, pickle of Composite generated in that case
# crashed with max recursion limit. So we were not able to
# generate C code in that case.
factors = []
sd = dscalar()
means = dvector()
cst_05 = at.constant(0.5)
cst_m05 = at.constant(-0.5)
cst_2 = at.constant(2)
cst_m2 = at.constant(-2)
ones = at.constant(np.ones(10))
n = 85
if config.mode in ["DebugMode", "DEBUG_MODE"]:
n = 10
for i in range(n):
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
cst_05 * (sd**cst_m2) / np.pi
)
factors.append(at_sum(f))
logp = add(*factors)
vars = [sd, means]
# Make sure that C compilation is used
mode = Mode("cvm", self.rewrites)
dlogp = function(vars, [aesara.grad(logp, v) for v in vars], mode=mode)
# Make sure something was fused
assert any(
isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite)
for node in dlogp.maker.fgraph.toposort()
)
def test_add_mul_fusion_inplace(self):
rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"inplace",
],
exclude=["cxx_only", "BlasOpt"],
)
mode = Mode(self.mode.linker, rewrites)
x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z
f = function([x, y, z], out, mode=mode)
topo = [n for n in f.maker.fgraph.toposort()]
assert len(topo) == 2
assert topo[-1].op.inplace_pattern
new_out = f.maker.fgraph.outputs[0]
assert isinstance(new_out.owner.op, Elemwise)
assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add)
assert len(new_out.owner.inputs) == 4
# TODO: Do we really need to do this?
_ = f(
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_no_c_code(self):
r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
# This custom `Op` has no `c_code` method
class NoCCodeOp(aes.basic.UnaryScalarOp):
def impl(self, x):
return x * 2
no_c_code_op = Elemwise(NoCCodeOp(aes.basic.upgrade_to_float))
mode = Mode(linker="cvm")
mode._optimizer = mode._optimizer.including(
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"inplace",
)
x = vector()
out = x * no_c_code_op(x + 1)
f = function([x], out, mode=mode)
assert not any(
isinstance(getattr(n.op, "scalar_op"), aes.basic.Composite)
for n in f.maker.fgraph.toposort()
)
    @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
    def test_test_values(self, test_value):
        """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.

        The test values we're talking about are the ones used when C implementations
        are checked.

        """
        rewrites = RewriteDatabaseQuery(
            include=[
                "local_elemwise_fusion",
                "composite_elemwise_fusion",
                "canonicalize",
            ],
            exclude=["cxx_only", "BlasOpt"],
        )

        mode = Mode(self.mode.linker, rewrites)

        x, y, z = dmatrices("xyz")

        x.tag.test_value = test_value
        y.tag.test_value = test_value
        z.tag.test_value = test_value

        if test_value.size == 0:
            # A size-zero test value should make compilation fail under
            # `compute_test_value="raise"`...
            cm = pytest.raises(ValueError)
        else:
            # ...otherwise compilation should succeed.
            cm = contextlib.suppress()

        with config.change_flags(
            compute_test_value="raise", compute_test_value_opt="raise"
        ):
            out = x * y + z
            with cm:
                f = function([x, y, z], out, mode=mode)

        if test_value.size != 0:
            # Confirm that the fusion happened
            assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
            assert len(f.maker.fgraph.toposort()) == 1

            x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
            assert np.array_equal(
                f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
            )
class TimesN(aes.basic.UnaryScalarOp):
    """
    Used in test TestCompositeCodegen

    Must be outside of the class, otherwise, the c cache code can't
    pickle this class and this cause stuff printing during test.
    """

    def __eq__(self, other):
        # Equal only to another op that also matches on the factor `n`.
        return super().__eq__(other) and self.n == other.n

    def __hash__(self):
        # Keep the hash consistent with `__eq__` by mixing in `n`.
        return super().__hash__() ^ hash(self.n)

    def __init__(self, n, *args, **kwargs):
        # `n` is the constant multiplication factor applied by this scalar op.
        self.n = n
        aes.basic.UnaryScalarOp.__init__(self, *args, **kwargs)

    def impl(self, x):
        # Python-side (non-C) implementation.
        return x * self.n

    def c_support_code_apply(self, node, nodename):
        # Emit a helper C function whose name is unique to this apply node.
        n = str(self.n)
        return (
            """
        float %(nodename)s_timesn(float x) { return x * %(n)s; }
        """
            % locals()
        )

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return f"{z} = {name}_timesn({x});"
class TestCompositeCodegen:
    """
    Test the `Composite` `Op`s' code generation in a case where multiple
    scalar `Op`s provide C support code.
    """

    def setup_method(self):
        to_float = aes.basic.upgrade_to_float

        self.scal_times_2 = TimesN(2, to_float, name="times_2")
        self.times_2 = Elemwise(self.scal_times_2, name="times_2")

        self.scal_times_3 = TimesN(3, to_float, name="times_3")
        self.times_3 = Elemwise(self.scal_times_3, name="times_3")

        self.x = fvector()

    def test_nested_composite(self):
        nested = self.times_3(self.times_2(self.x))
        f = function([self.x], nested)
        if config.mode != "FAST_COMPILE":
            assert len(f.maker.fgraph.toposort()) == 1
        assert np.all(f([1, 2, 3]) == [6, 12, 18])

    def test_local_useless_composite(self):
        x = aes.float32()
        c = aes.Composite([x], [x + 1, x - 1])
        X = matrix()
        o = Elemwise(scalar_op=c)(X)
        mode = get_default_mode().including("local_useless_composite")

        # Using a single output should leave one node with a single output.
        for out_idx, expected in ((0, [[2.0]]), (1, [[0.0]])):
            f = function([X], o[out_idx], mode=mode)
            topo = f.maker.fgraph.toposort()
            assert len(topo) == 1
            assert len(topo[0].outputs) == 1
            utt.assert_allclose(f([[1.0]]), expected)
def test_local_useless_slice():
# test a simple matrix
x = matrix("x")
# (diff hunk elided) @@ -1616,191 +462,6 @@ class TestLocalUselessIncSubtensorAlloc:
assert check_stack_trace(f2, ops_to_check="last")
class TestShapeRewriter:
    """Tests for the `ShapeFeature`-driven shape rewrites."""

    def test_basic(self):
        # `(v + m).shape` should be computable without ever evaluating `add`.
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"
        v = vector()
        m = matrix()
        f = function([v, m], (v + m).shape, mode=mode)
        for node in f.maker.fgraph.toposort():
            assert node.op != add

    def test_constant(self):
        # A dimshuffled broadcastable dimension has a constant shape, so the
        # whole graph should fold down to a single `DeepCopyOp`.
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"

        v = vector()
        f = function([v], v.dimshuffle("x", "x", 0).shape[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert topo[0].op == deep_copy_op

    @staticmethod
    def max_pool_c01b(c01b, pool_shp, pool_stride, img_shp):
        """
        Like max_pool but with input using axes ('c', 0, 1, 'b')
          (Alex Krizhevsky format)

        pool_shp, pool_stride and img_shp are int that represent
        the same shp in x and y.
        """
        mx = None

        # Compute index in pooled space of last needed pool
        # (needed = each input pixel must appear in at least one pool)
        def last_pool(im_shp, p_shp, p_strd):
            rval = int(np.ceil(float(im_shp - p_shp) / p_strd))
            assert p_strd * rval + p_shp >= im_shp
            assert p_strd * (rval - 1) + p_shp < im_shp
            return rval

        # Compute starting row of the last pool
        last_pool_r = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        # Compute number of rows needed in img for all indexes to work out
        required_r = last_pool_r + pool_shp

        last_pool_c = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        required_c = last_pool_c + pool_shp

        # Pad with -inf so out-of-image positions never win the max.
        wide_infinity = at.alloc(
            -np.inf, c01b.shape[0], required_r, required_c, c01b.shape[3]
        )

        c01b = set_subtensor(wide_infinity[:, 0:img_shp, 0:img_shp, :], c01b)

        for row_within_pool in range(pool_shp):
            row_stop = last_pool_r + row_within_pool + 1
            for col_within_pool in range(pool_shp):
                col_stop = last_pool_c + col_within_pool + 1
                cur = c01b[
                    :,
                    row_within_pool:row_stop:pool_stride,
                    col_within_pool:col_stop:pool_stride,
                    :,
                ]
                if mx is None:
                    mx = cur
                else:
                    mx = maximum(mx, cur)
        return mx

    def test_broadcasted_dims(self):
        # This tests a case that caused a crash during rewriting
        shp = (1, 1, 1, 1)
        rng = np.random.default_rng(utt.fetch_seed())
        a = shared(rng.random(shp).astype(config.floatX))
        out = self.max_pool_c01b(a, 1, 1, 1)

        # max_pool_c01b use -inf and this will trigger DebugMode error.
        mode = copy.copy(get_default_mode())
        mode.check_isfinite = False
        f = function([], out, mode=mode)
        f()

    def test_constant_merge(self):
        # This tests the error in gh-1122 that is caused by the
        # combination of merge rewriter and ShapeFeature.

        x = at.constant([0, 0])
        y = x[1:]
        x1 = x - at.join(0, y, y)
        x1.eval()

    def test_local_track_shape_i(self):
        class IdentityNoShape(Op):
            """Op that does not infer the output shape from the input one"""

            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])

            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()

            # def infer_shape(self, fgraph, node, (xshp,)):
            # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])]

        identity_noshape = IdentityNoShape()

        class IdentityShape(Op):
            """Op that does infer the output shape from the input one"""

            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])

            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()

            def infer_shape(self, fgraph, node, xshp_):
                # Could also just return.
                (xshp,) = xshp_
                return (xshp,)

        identity_shape = IdentityShape()

        @node_rewriter([IdentityNoShape])
        def local_identity_noshape_to_identity_shape(fgraph, node):
            """Transform the first `Op` into the second."""
            if isinstance(node.op, IdentityNoShape):
                return [identity_shape(node.inputs[0])]

        mode = get_default_mode().including("ShapeOpt", "specialize")
        rng = np.random.default_rng(utt.fetch_seed())
        x = tensor3("x")
        ins_x = identity_noshape(x)

        # Without the rewrite
        f = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((3, 4, 7)).astype(config.floatX)
        assert np.all(f(xval) == [3, 4, 7])
        f_ops = [node.op for node in f.maker.fgraph.toposort()]
        assert len(f_ops) == 5
        assert identity_noshape in f_ops
        assert identity_shape not in f_ops

        # Register the rewrite
        register_specialize(local_identity_noshape_to_identity_shape)

        mode = get_default_mode().including("ShapeOpt", "specialize")
        # The `identity_shape` op should not be needed anymore to compute
        # the shape
        g = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(g(xval) == [6, 1, 2])
        g_ops = [node.op for node in g.maker.fgraph.toposort()]
        assert len(g_ops) == 4
        assert identity_noshape not in g_ops
        assert identity_shape not in g_ops

        # Test multiple applications of an `Op` without an `Op.infer_shape`
        ins_x3 = identity_noshape(identity_noshape(identity_noshape(x)))
        h = function([x], ins_x3.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(h(xval) == [6, 1, 2])
        h_ops = [node.op for node in h.maker.fgraph.toposort()]
        assert len(h_ops) == 4
        assert identity_noshape not in h_ops
        assert identity_shape not in h_ops

    def test_no_shapeopt(self):
        """Test that a basic example works even when `ShapeOpt` is excluded."""
        X = matrix()
        expr = X.shape[0]

        mode = get_default_mode().excluding("ShapeOpt")
        f = function([X], expr, mode=mode)
        # FIXME: This is not a good test.
        f([[1, 2], [2, 3]])
class TestUselessCheckAndRaise:
def test_basic(self):
mode = get_default_mode().including(
# (diff hunk elided) @@ -2739,136 +1400,6 @@ def test_local_flatten_lift(i):
assert isinstance(topo[-1].op, Elemwise)
class TestReshape:
    """Consecutive `Reshape`s should collapse into a single one."""

    def setup_method(self):
        self.mode = rewrite_mode
        self.op = Reshape

    def test_local_reshape(self):
        x = fmatrix()
        reshaped_once = self.op(3)(x, [2, 3, 4])
        reshaped_twice = self.op(1)(reshaped_once, [24])
        f = function([x], reshaped_twice, mode=self.mode)

        n_reshapes = sum(
            isinstance(node.op, self.op) for node in f.maker.fgraph.toposort()
        )
        assert n_reshapes == 1

        # Check stack trace
        assert check_stack_trace(f, ops_to_check=[self.op])
class TestLocalUselessReshape:
    """Tests for the `local_useless_reshape` rewrite."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    @staticmethod
    def _assert_no_reshape(fn):
        # Shared check: the compiled graph has no `Reshape` node left.
        assert not any(
            isinstance(node.op, Reshape) for node in fn.maker.fgraph.toposort()
        )

    def test_0(self):
        mode = get_default_mode().including("local_useless_reshape")
        i = iscalar("i")
        m = at.mgrid[
            0:i,
        ]
        self._assert_no_reshape(function([i], m, mode=mode))

    def test_1(self):
        x = matrix("x")
        r = x.reshape(x.shape)

        m1 = get_default_mode().including("local_useless_reshape")
        self._assert_no_reshape(function([x], r, mode=m1))

        m2 = m1.excluding("ShapeOpt")
        self._assert_no_reshape(function([x], r, mode=m2))

        # We do not need tests checking that stack traces are copied over,
        # because local_useless_reshape only removes nodes from the graph

    def test_2(self):
        x = matrix("x")
        r = x.reshape([Shape_i(i)(x) for i in range(x.ndim)])

        m1 = get_default_mode().including("local_useless_reshape")
        self._assert_no_reshape(function([x], r, mode=m1))

        m2 = m1.excluding("ShapeOpt")
        self._assert_no_reshape(function([x], r, mode=m2))

    def test_m1(self):
        x = matrix("x")
        r = x.reshape((x.shape[0], -1))

        m1 = get_default_mode().including("local_useless_reshape")
        self._assert_no_reshape(function([x], r, mode=m1))

        m2 = m1.excluding("ShapeOpt")
        self._assert_no_reshape(function([x], r, mode=m2))
class TestLocalReshapeToDimshuffle:
    """`Reshape`s that only add broadcastable dimensions should become `DimShuffle`s."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    def test_1(self):
        reshape_lift = out2in(local_reshape_to_dimshuffle)
        useless_reshape = out2in(local_useless_reshape)
        x = shared(self.rng.standard_normal((4,)))
        y = shared(self.rng.standard_normal((5, 6)))
        reshape_x = reshape(x, (1, 4))
        reshape_y = reshape(y, (1, 5, 1, 6, 1, 1))

        g = FunctionGraph([x, y], [reshape_x, reshape_y])

        # Before the rewrites: both outputs are plain `Reshape`s.
        assert str(g) == (
            "FunctionGraph(Reshape{2}"
            "(<TensorType(float64, (None,))>, "
            "TensorConstant{[1 4]}), "
            "Reshape{6}"
            "(<TensorType(float64, (None, None))>, "
            "TensorConstant{[1 5 1 6 1 1]}))"
        )

        reshape_lift.rewrite(g)
        useless_reshape.rewrite(g)

        # After: the unit dimensions are introduced by `DimShuffle`s; only the
        # genuine (5, 6) reshape remains.
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{x,0}"
            "(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,0,x,1,x,x}"
            "(Reshape{2}(<TensorType(float64, (None, None))>, "
            "TensorConstant{[5 6]})))"
        )

        # Check stacktrace was copied over correctly after the rewrite was applied
        assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
def test_local_reshape_lift():
    """`exp(x).reshape(...)` should be rewritten so the `Reshape` comes first."""
    x = tensor4()
    flat_exp = exp(x).reshape([x.size])
    assert flat_exp.ndim == 1

    mode = get_default_mode().including("local_reshape_lift")
    f = function([x], flat_exp, mode=mode)
    f(np.random.random((5, 4, 3, 2)).astype(config.floatX))

    last_ops = [node.op for node in f.maker.fgraph.toposort()[-2:]]
    assert isinstance(last_ops[0], Reshape)
    assert isinstance(last_ops[1], Elemwise)

    assert check_stack_trace(f, ops_to_check="last")
class TestLiftTransposeThroughDot:
def simple_rewrite(self, g):
out2in(local_useless_elemwise).rewrite(g)
# (diff hunk elided) @@ -2918,160 +1449,6 @@ def test_local_upcast_elemwise_constant_inputs():
function([v], true_div(v, 2))
class TestShapeI(utt.InferShapeTester):
    """Tests for the `Shape_i` `Op`."""

    def setup_method(self):
        super().setup_method()

    def test_perform(self):
        rng = np.random.default_rng(utt.fetch_seed())

        vec = vector()
        vec_val = rng.random((3)).astype(config.floatX)
        vec_fn = function([vec], Shape_i(0)(vec))
        utt.assert_allclose(vec_fn(vec_val), vec_val.shape[0])

        mat = matrix()
        mat_val = rng.random((4, 3)).astype(config.floatX)
        for axis in range(2):
            mat_fn = function([mat], Shape_i(axis)(mat))
            utt.assert_allclose(mat_fn(mat_val), mat_val.shape[axis])

    def test_infer_shape(self):
        mat = matrix()
        mat_val = np.random.random((3, 4)).astype(config.floatX)
        for axis in (0, 1):
            self._compile_and_check([mat], [Shape_i(axis)(mat)], [mat_val], Shape_i)
class TestSameShape:
    """Tests for `ShapeFeature.same_shape`."""

    def test_scalar(self):
        # A scalar plus a constant trivially has the same (empty) shape.
        x = scalar()
        cst = at.constant(1)
        o = x + cst
        fgraph = FunctionGraph([x], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o)

    def test_vector(self):
        # Adding a broadcasted constant does not change the vector's shape.
        x = vector()
        cst = at.constant(1)
        o = x + cst
        fgraph = FunctionGraph([x], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o)

    def test_no_static_shapes(self):
        x = vector()
        y = vector()
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        # We no longer assume that `x` has the same shape as `y` simply because
        # neither has static shape information. Instead, when there is no
        # static shape information is available, we assume that `x` and/or `y`
        # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
        # combination of the two.
        assert not shape_feature.same_shape(x, o)
        # The following case isn't implemented
        assert not shape_feature.same_shape(y, o)

    @pytest.mark.parametrize(
        "y_dim_0",
        [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
    )
    def test_vector_dim(self, y_dim_0):
        # Per-dimension comparison: dim 0 is statically known equal, dim 1 is not.
        x = at.tensor(dtype="floatX", shape=(2, None))
        y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o, 0, 0)
        assert not shape_feature.same_shape(x, o, 1, 1)

    def test_vector_dim_err(self):
        # Out-of-range dimension indices must raise `IndexError`.
        x = vector()
        y = vector()
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        with pytest.raises(IndexError):
            shape_feature.same_shape(x, o, 1, 0)
        with pytest.raises(IndexError):
            shape_feature.same_shape(x, o, 0, 1)
@pytest.mark.parametrize(
    "shape",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape(shape):
    """`specify_shape(x, s).shape` should reduce to `s`, bypassing `x` entirely."""
    x = vector()
    s = specify_shape(x, shape).shape
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # The rewritten graph refers to the shape symbol directly...
    assert shape in fgraph.variables
    # ...and no longer needs `x` at all.
    assert x not in fgraph.variables
@pytest.mark.parametrize(
    "s1",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
    """A partially-specified shape keeps `x` but still drops the `SpecifyShape` node."""
    x = matrix()
    s = specify_shape(x, (s1, None)).shape
    fgraph = FunctionGraph(outputs=[s], clone=False)

    def has_specify_shape():
        return any(isinstance(node.op, SpecifyShape) for node in fgraph.apply_nodes)

    assert has_specify_shape()

    _ = rewrite_graph(fgraph, clone=False)

    # `x` is still needed (for the unspecified dimension), `s1` is used
    # directly, and the `SpecifyShape` itself is gone.
    assert x in fgraph.variables
    assert s1 in fgraph.variables
    assert not has_specify_shape()
def test_local_Shape_i_of_broadcastable():
    """`Shape_i` of a broadcastable dimension should constant-fold to 1."""
    x = tensor(np.float64, [False, True])
    s = Shape_i(1)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # The broadcastable dimension's length is statically known, so `x` is no
    # longer needed and the output is the constant 1.
    assert x not in fgraph.variables
    assert fgraph.outputs[0].data == 1

    # A test for a non-`TensorType`
    class MyType(Type):
        ndim = 1

        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, MyType) and other.thingy == self.thingy

    class MyVariable(Variable):
        pass

    x = MyVariable(MyType(), None, None)
    s = Shape_i(0)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # Non-`TensorType` inputs must be left untouched by the rewrite.
    assert fgraph.outputs[0] == s
def test_assert_op_gradient():
x = vector("x")
assert_op = Assert()
# (diff hunk elided) @@ -3183,283 +1560,6 @@ def test_local_useless_alloc():
assert isinstance(topo[-1].op, Alloc)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_scalar(return_index, return_counts, return_inverse):
    """`unique` of a scalar should reduce to a simple `DimShuffle`."""
    x = dscalar()
    y = unique(
        x,
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=None,
    )

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    rewritten_fg = rewrite_graph(
        y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
    )
    rewritten = rewritten_fg.outputs[0]
    assert isinstance(rewritten.owner.op, DimShuffle)
    assert rewritten.owner.inputs[0] == x

    # Compile the reference alongside the rewritten graph (excluding the
    # rewrite so the reference stays unrewritten) and compare their outputs.
    compare_mode = get_default_mode().excluding("local_Unique_scalar")
    y_fn = function([x], [y, rewritten], mode=compare_mode)

    x_val = np.array(-10.0, dtype=np.float64)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Alloc_lift(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """`unique(alloc(x, ...))` should be rewritten as `unique(x)`."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        alloc(x, *new_shape),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )
    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Alloc_lift"],
        exclude=["local_Unique_scalar"],
    )
    rewritten = rewritten_fg.outputs[0]

    # `Unique` is now applied to `x` directly, with no `Alloc` left anywhere.
    assert isinstance(rewritten.owner.op, Unique)
    assert rewritten.owner.inputs[0] == x
    assert not any(isinstance(node.op, Alloc) for node in rewritten_fg.apply_nodes)

    # The rewrite has already been applied to `rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result.  The remaining exclusions simply allow us to
    # perform the check below that makes sure the original `Alloc` is present
    # in our reference (sub)graph.
    compare_mode = get_default_mode().excluding(
        "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
    )
    y_fn = function([x], [y, rewritten], mode=compare_mode)

    # Make sure that the original `Alloc` is used to compute the reference `y`
    # result
    assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """`unique(broadcast_to(x, ...))` should be rewritten as `unique(x)`."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        BroadcastTo()(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )
    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
        exclude=["local_Unique_scalar"],
    )
    rewritten = rewritten_fg.outputs[0]

    # `Unique` is now applied to `x` directly, with no `BroadcastTo` left.
    assert isinstance(rewritten.owner.op, Unique)
    assert rewritten.owner.inputs[0] == x
    assert not any(
        isinstance(node.op, BroadcastTo) for node in rewritten_fg.apply_nodes
    )

    # The rewrite has already been applied to `rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result.
    compare_mode = get_default_mode().excluding("local_Unique_BroadcastTo_lift")
    y_fn = function([x], [y, rewritten], mode=compare_mode)

    # Make sure that the original `BroadcastTo` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
    )

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, repeats, repeat_axis",
    [
        (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Repeat(
    x_val,
    unique_axis,
    repeats,
    repeat_axis,
    return_index,
    return_counts,
    return_inverse,
):
    """`unique(repeat(x, ...))` should be rewritten as `unique(x)`."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        repeat(x, tuple(repeats), axis=repeat_axis),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Repeat_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    # `Unique` is now applied to `x` directly, with no `Repeat` left anywhere.
    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Repeat) for node in y_rewritten_fg.apply_nodes)

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_Repeat_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
    # Make sure that the original `Repeat` is used to compute the
    # reference `y` result
    assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_second(
    x_val, unique_axis, new_shape, return_index, return_counts, return_inverse
):
    """`unique(second(a, x))` should be rewritten as `unique(x)`."""
    x = as_tensor_variable(x_val).type()
    a = np.zeros(tuple(new_shape), dtype=x.dtype)
    y = unique(
        second(a, x),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_second_lift"],
        exclude=["local_Unique_scalar", "topo_constant_folding"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)

    # Walk past an optional `DimShuffle` introduced by the lift to reach `x`.
    y_rewritten_start = y_rewritten_start.owner.inputs[0]

    if y_rewritten_start.owner and isinstance(y_rewritten_start.owner.op, DimShuffle):
        y_rewritten_start = y_rewritten_start.owner.inputs[0]

    assert y_rewritten_start == x
    assert not any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_rewritten_fg.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    y_fn = function([x], [y, y_rewritten], mode=Mode(optimizer=OPT_NONE))

    # Make sure that the original `Second` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_fn.maker.fgraph.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
def test_local_merge_consecutive_specify_shape():
x = matrix()
s = at.as_tensor([iscalar(), iscalar()])
# (diff hunk elided) @@ -3501,64 +1601,6 @@ def test_printing():
assert pprint(v) == "[a, b]"
def test_local_remove_scalar_BroadcastTo():
    """Broadcasting a scalar to a scalar shape is a no-op and should vanish."""
    x = dscalar()
    y = BroadcastTo()(x, ())
    assert isinstance(y.owner.op, BroadcastTo)

    rewritten = rewrite_graph(
        y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
    )
    assert rewritten is x
def test_local_useless_dimshuffle_makevector():
    """`MakeVector(a).dimshuffle(())` should collapse back to `a` itself."""
    a = scalar()
    y = MakeVector(config.floatX)(a).dimshuffle(())

    fg = FunctionGraph(outputs=[y], copy_inputs=False)
    rewritten_fg = rewrite_graph(
        fg,
        clone=False,
        include=["canonicalize", "local_useless_dimshuffle_makevector"],
    )
    assert rewritten_fg.outputs[0] == a
def test_Shape_i_canonicalize():
    """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension.

    In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)``
    and nothing else. The rewrites `local_shape_to_shape_i`,
    `local_subtensor_remove_broadcastable_index`, and
    `local_useless_dimshuffle_makevector` need to work together to accomplish
    this, and we confirm that here.

    """
    x = vector()
    y_fg = FunctionGraph(
        outputs=[shape(x)[0]], copy_inputs=False, features=[ShapeFeature()]
    )
    rewritten = rewrite_graph(y_fg, clone=False, include=["canonicalize"]).outputs[0]

    shape_apply = rewritten.owner
    assert isinstance(shape_apply.op, Shape_i)
    assert shape_apply.op.i == 0
    assert shape_apply.inputs[0] == x
class TestLocalElemwiseAlloc:
"""
# (diff hunk elided) @@ -3847,3 +1889,6 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.basic_opt import register_useless # noqa: F401 F811
with pytest.deprecated_call():
from aesara.tensor.rewriting.basic import ShapeFeature # noqa: F401
import contextlib
import numpy as np
import pytest
import aesara
import aesara.scalar as aes
import aesara.tensor as at
from aesara import shared
from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode
from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import check_stack_trace, out2in
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.misc.safe_asarray import _asarray
from aesara.scalar.basic import Composite
from aesara.tensor.basic import MakeVector
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import (
add,
bitwise_and,
bitwise_or,
cos,
cosh,
dot,
eq,
exp,
int_div,
invert,
iround,
log,
log2,
log10,
mul,
neg,
neq,
)
from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import reciprocal
from aesara.tensor.math import round as at_round
from aesara.tensor.math import sin, sinh, sqr, sqrt
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
from aesara.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
from aesara.tensor.shape import reshape
from aesara.tensor.type import (
TensorType,
dmatrices,
dscalar,
dvector,
fscalar,
fvector,
matrix,
scalar,
tensor,
vector,
vectors,
)
from tests import unittest_tools as utt
# Graph rewriter that applies `local_dimshuffle_lift` from outputs toward inputs.
dimshuffle_lift = out2in(local_dimshuffle_lift)
def ds(x, y):
    """Shorthand: apply a `DimShuffle` with the new dimension order `y` to `x`."""
    op = DimShuffle(x.type.broadcastable, y)
    return op(x)
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
    """Create three named float64 tensor variables with the given static shapes."""
    x, y, z = (
        TensorType(shape=bc, dtype="float64")(name)
        for bc, name in zip((xbc, ybc, zbc), "xyz")
    )
    return x, y, z
class TestDimshuffleLift:
    """Tests for `local_dimshuffle_lift`: merging, eliminating, and lifting
    `DimShuffle` ops through `Elemwise` nodes.

    NOTE(review): these tests compare exact graph string representations, so
    they are sensitive to how `DimShuffle` prints (inplace vs. not).
    """
    def test_double_transpose(self):
        # Two consecutive transposes cancel out entirely.
        x, y, z = inputs()
        e = ds(ds(x, (1, 0)), (1, 0))
        g = FunctionGraph([x], [e])
        assert (
            str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # no need to check_stack_trace as graph is supposed to be empty
    def test_merge2(self):
        # Two nested, non-cancelling `DimShuffle`s merge into one.
        x, y, z = inputs()
        e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
        g = FunctionGraph([x], [e])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")
    def test_elim3(self):
        # Three chained `DimShuffle`s whose composition is the identity are removed.
        x, y, z = inputs()
        e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
        g = FunctionGraph([x], [e])
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
            "(InplaceDimShuffle{0,x,1}(x))))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)", str(g)
        # no need to check_stack_trace as graph is supposed to be empty
    def test_lift(self):
        # `DimShuffle`s inserted by broadcasting get lifted above the `Elemwise`s.
        x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
        e = x + y + z
        g = FunctionGraph([x, y, z], [e])
        # It does not really matter if the DimShuffles are inplace
        # or not.
        init_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
        )
        init_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
        )
        assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)
        rewrite_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
        )
        rewrite_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")
    def test_recursive_lift(self):
        # A single `transform` call lifts `DimShuffle`s recursively through the graph.
        v = vector(dtype="float64")
        m = matrix(dtype="float64")
        out = ((v + 42) * (m + 84)).T
        g = FunctionGraph([v, m], [out])
        init_str_g = (
            "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
            "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None,))>, "
            "InplaceDimShuffle{x}(TensorConstant{42}))), "
            "Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None, None))>, "
            "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
        )
        assert str(g) == init_str_g
        new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
        new_g = FunctionGraph(g.inputs, [new_out])
        rewrite_str_g = (
            "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{0,x}(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{42})), "
            "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
            "(<TensorType(float64, (None, None))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
        )
        assert str(new_g) == rewrite_str_g
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(new_g, ops_to_check="all")
    def test_useless_dimshuffle(self):
        # An identity-order `DimShuffle` is simply removed.
        x, _, _ = inputs()
        e = ds(x, (0, 1))
        g = FunctionGraph([x], [e])
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")
    def test_dimshuffle_on_broadcastable(self):
        # Reordering only broadcastable (length-1) dimensions is useless and
        # gets removed; genuinely useful `DimShuffle`s are kept.
        x, y, z = inputs([False, True], [True, False, True], [False, False, True])
        u = at.constant(1)
        ds_x = ds(x, (0, "x"))  # useless
        ds_y = ds(y, (2, 1, 0))  # useless
        ds_z = ds(z, (2, 1, 0))  # useful
        ds_u = ds(u, ("x"))  # useful
        g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        dimshuffle_lift.rewrite(g)
        assert (
            str(g)
            == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")
def test_local_useless_dimshuffle_in_reshape():
    """`DimShuffle`s that only add/drop broadcastable dims before a `Reshape`
    are useless and should be removed; order-changing ones must be kept.
    """
    vec = TensorType(shape=(False,), dtype="float64")("vector")
    mat = TensorType(shape=(False, False), dtype="float64")("mat")
    row = TensorType(shape=(True, False), dtype="float64")("row")
    col = TensorType(shape=(False, True), dtype="float64")("col")
    reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape)
    reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape)
    reshape_dimshuffle_row = reshape(row.dimshuffle(1, "x"), row.shape)
    reshape_dimshuffle_col = reshape(col.dimshuffle(0), col.shape)
    g = FunctionGraph(
        [vec, mat, row, col],
        [
            reshape_dimshuffle_vector,
            reshape_dimshuffle_mat,
            reshape_dimshuffle_row,
            reshape_dimshuffle_col,
        ],
    )
    assert str(g) == (
        "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
        "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
        "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
        "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
    )
    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
    useless_dimshuffle_in_reshape.rewrite(g)
    # All the broadcastable-only `DimShuffle`s should be gone.
    assert str(g) == (
        "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
        "Reshape{2}(mat, Shape(mat)), "
        "Reshape{2}(row, Shape(row)), "
        "Reshape{2}(col, Shape(col)))"
    )
    # Check stacktrace was copied over correctly after rewrite was applied
    assert check_stack_trace(g, ops_to_check="all")
    # Check that the rewrite does not get applied when the order
    # of dimensions has changed.
    reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
    h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
    str_h = str(h)
    useless_dimshuffle_in_reshape.rewrite(h)
    assert str(h) == str_h
class TestFusion:
    """Tests for the `Elemwise` fusion rewrites (`local_elemwise_fusion` and
    `composite_elemwise_fusion`), which merge chains of element-wise
    operations into a single `Composite` `Elemwise`.
    """
    # Rewrite query used by every test in this class (C-only rewrites excluded).
    rewrites = RewriteDatabaseQuery(
        include=[
            "local_elemwise_fusion",
            "composite_elemwise_fusion",
            "canonicalize",
            "inplace",
        ],
        exclude=["cxx_only", "BlasOpt"],
    )
    mode = Mode(get_default_mode().linker, rewrites)
    _shared = staticmethod(shared)
    # Op classes to ignore when counting toposort nodes (overridable by subclasses).
    topo_exclude = ()
    # Helper used at class-definition time to build 5x5 constant test arrays.
    def my_init(dtype="float64", num=0):
        return np.zeros((5, 5), dtype=dtype) + num
    # Symbolic inputs of various dtypes/ranks used by the parametrized cases below.
    fw, fx, fy, fz = [
        tensor(dtype="float32", shape=[False] * 2, name=n) for n in "wxyz"
    ]
    dw, dx, dy, dz = [
        tensor(dtype="float64", shape=[False] * 2, name=n) for n in "wxyz"
    ]
    ix, iy, iz = [tensor(dtype="int32", shape=[False] * 2, name=n) for n in "xyz"]
    fv = fvector("v")
    fs = fscalar("s")
    # Concrete values fed to the compiled functions.
    fwv = my_init("float32", 1)
    fxv = my_init("float32", 2)
    fyv = my_init("float32", 3)
    fzv = my_init("float32", 4)
    fvv = _asarray(np.random.random(5), dtype="float32")
    fsv = np.asarray(np.random.random(), dtype="float32")
    dwv = my_init("float64", 5)
    ixv = _asarray(my_init(num=60), dtype="int32")
    iyv = _asarray(my_init(num=70), dtype="int32")
    izv = _asarray(my_init(num=70), dtype="int32")
    fwx = fw + fx
    ftanx = tan(fx)
    # Each case is: (graph, symbolic inputs, input values,
    #                expected node count after fusion, expected output value,
    #                expected output dtype — possibly keyed by `cast_policy`).
    @pytest.mark.parametrize(
        "case",
        [
            (
                fx + fy + fz,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + fzv,
                "float32",
            ),  # 0
            (
                fx * fy * fz,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv * fyv * fzv,
                "float32",
            ),  # 1
            (
                fx + fy * fz,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv * fzv,
                "float32",
            ),  # 2
            (
                fx * fy + fz,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv * fyv + fzv,
                "float32",
            ),  # 3
            (
                fw + fx + fy + fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),
            (
                (fw + fx) + (fy + fz),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),  # 5
            (
                ((fw + fx) + fy) + fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),
            (
                (fw + (fx + fy)) + fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),
            (
                (fw + (fx + fy) + fz),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),
            (
                fw + (fx + (fy + fz)),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),
            (
                (fw + fx) + (fy + fz),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv,
                "float32",
            ),  # 10
            (
                fw * fx * fy * fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv * fxv * fyv * fzv,
                "float32",
            ),
            (
                fw + fx * fy * fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv * fyv * fzv,
                "float32",
            ),
            (
                fx + fy * fz * fx,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv * fzv * fxv,
                "float32",
            ),
            (
                fx * fy + fz + fy,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv * fyv + fzv + fyv,
                "float32",
            ),
            (
                fx * fy * fz * fw + fx + fy + fz + fw,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv,
                "float32",
            ),  # 15
            # test with constant
            (
                (fw + fx) + (fy + fz) + 2.0,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv + 2,
                "float32",
            ),
            (
                ((fw + fx) + 2.0 + fy) + fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv + 2,
                "float32",
            ),
            (
                (fw + (fx + 2.0 + fy)) + fz,
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv + 2,
                "float32",
            ),
            (
                (fw + (fx + fy) + 2 + fz),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv + 2,
                "float32",
            ),
            (
                fw + (fx + (fy + fz) + 2.0),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv + 2,
                "float32",
            ),  # 20
            (
                2 + (fw + fx) + (fy + fz),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                1,
                fwv + fxv + fyv + fzv + 2,
                "float32",
            ),
            # mix float32 and float64
            (
                2 + (dw + fx) + (fy + fz),
                (dw, fx, fy, fz),
                (dwv, fxv, fyv, fzv),
                1,
                dwv + fxv + fyv + fzv + 2,
                "float64",
            ),
            (
                2 + (fw + dw) + (fy + fz),
                (fw, dw, fy, fz),
                (fwv, dwv, fyv, fzv),
                1,
                fwv + dwv + fyv + fzv + 2,
                "float64",
            ),
            (
                2 + (fw + fx) + (dw + fz),
                (fw, fx, dw, fz),
                (fwv, fxv, dwv, fzv),
                1,
                fwv + fxv + dwv + fzv + 2,
                "float64",
            ),
            (
                2 + (fw + fx) + (fy + dw),
                (fw, fx, fy, dw),
                (fwv, fxv, fyv, dwv),
                1,
                fwv + fxv + fyv + dwv + 2,
                "float64",
            ),  # 25
            # test when their is other op then elemwise.
            (
                (fwx.sum()) + (fwx) + (fy + fz),
                (fw, fx, fy, fz),
                (fwv, fxv, fyv, fzv),
                4,
                (fwv + fxv).sum() + fwv + fxv + fyv + fzv,
                "float32",
            ),
            # test other elemwise op
            (
                fx + fy + cos(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.cos(fzv),
                "float32",
            ),
            (
                fx + fy + cosh(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.cosh(fzv),
                "float32",
            ),
            (
                fx + fy + abs(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.absolute(fzv),
                "float32",
            ),
            (
                ix + iy + abs(iz),
                (ix, iy, iz),
                (ixv, iyv, izv),
                1,
                ixv + iyv + np.absolute(izv),
                "int32",
            ),  # 30
            (
                fx + fy + log(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.log(fzv),
                "float32",
            ),
            (
                fx + fy + log2(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.log2(fzv),
                "float32",
            ),
            (
                fx + fy + log10(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.log10(fzv),
                "float32",
            ),
            (
                fx + fy**fz,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv**fzv,
                "float32",
            ),  # pow
            (
                fx + fy + exp(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv + fyv + np.exp(fzv),
                "float32",
            ),  # 35
            (
                fx - fy - fz,
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv - fzv,
                "float32",
            ),
            (
                fx - (fy / fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv / fzv),
                "float32",
            ),
            (
                fx - true_div(fy, 2),
                (fx, fy),
                (fxv, fyv),
                1,
                fxv - (fyv / 2),
                "float32",
            ),
            (
                fx - true_div(fy, fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv / fzv),
                "float32",
            ),
            (
                fx - int_div(ix * 100, iy * 1000),
                (fx, ix, iy),
                (fxv, ixv, iyv),
                1,
                fxv - ((ixv * 100) // (iyv * 1000)),
                {
                    "custom": "float64",
                    "numpy + floatX": config.floatX,
                    "numpy": "float64",
                },
            ),  # 40
            (fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"),
            (
                fx - (fy % fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv % fzv),
                "float32",
            ),
            (
                fx - (fy > fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv > fzv),
                "float32",
            ),
            (
                fx - (fy >= fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv >= fzv),
                "float32",
            ),
            (
                fx - (fy < fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv < fzv),
                "float32",
            ),  # 45
            (
                fx - (fy <= fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv <= fzv),
                "float32",
            ),
            (
                fx - eq(fy, fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv == fzv),
                "float32",
            ),
            (
                fx - neq(fy, fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - (fyv != fzv),
                "float32",
            ),
            (
                fx - fy + tan(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + np.tan(fzv),
                "float32",
            ),
            (
                fx - fy + tanh(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + np.tanh(fzv),
                "float32",
            ),  # 50
            (
                fx - fy + sin(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + np.sin(fzv),
                "float32",
            ),
            (
                fx - fy + sinh(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + np.sinh(fzv),
                "float32",
            ),
            (
                fx - fy + sqr(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + (fzv * fzv),
                "float32",
            ),
            (
                fx - fy + sqrt(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + np.sqrt(fzv),
                "float32",
            ),
            (
                fx - fy + reciprocal(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + (1 / fzv),
                "float32",
            ),  # 55
            (
                fx - fy + neg(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + (-fzv),
                "float32",
            ),
            (
                fx - fy + at_round(fz),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                fxv - fyv + np.round(fzv),
                "float32",
            ),
            (
                ix - iy + iround(fz),
                (ix, iy, fz),
                (ixv, iyv, fzv),
                1,
                ixv - iyv + np.round(fzv),
                "int64",
            ),
            # Bit op
            (
                fx - bitwise_or(iy, iz),
                (fx, iy, iz),
                (fxv, iyv, izv),
                1,
                fxv - (iyv | izv),
                {
                    "custom": "float64",
                    "numpy + floatX": config.floatX,
                    "numpy": "float64",
                },
            ),
            (
                fx - xor(iy, iz),
                (fx, iy, iz),
                (fxv, iyv, izv),
                1,
                fxv - (iyv ^ izv),
                {
                    "custom": "float64",
                    "numpy + floatX": config.floatX,
                    "numpy": "float64",
                },
            ),  # 60
            (
                fx - bitwise_and(iy, iz),
                (fx, iy, iz),
                (fxv, iyv, izv),
                1,
                fxv - (iyv & izv),
                {
                    "custom": "float64",
                    "numpy + floatX": config.floatX,
                    "numpy": "float64",
                },
            ),
            (
                fx - invert(iy),
                (fx, iy),
                (fxv, iyv),
                1,
                fxv - (~iyv),
                {
                    "custom": "float64",
                    "numpy + floatX": config.floatX,
                    "numpy": "float64",
                },
            ),
            (
                fx - at.cast(fy, dtype="float64"),
                (fx, fy),
                (fxv, fyv),
                1,
                fxv - np.asarray(fyv, "float64"),
                "float64",
            ),
            (
                at_pow(fx * fy + fz, fx * fy),
                (fx, fy, fz),
                (fxv, fyv, fzv),
                1,
                np.power(fxv * fyv + fzv, fxv * fyv),
                "float32",
            ),
            (
                fv + fy**fz,
                (fv, fy, fz),
                (fvv, fyv, fzv),
                2,
                fvv + fyv**fzv,
                "float32",
            ),  # fused with a dimshuffle #65
            (
                fv - fy + tanh(fz),
                (fv, fy, fz),
                (fvv, fyv, fzv),
                2,
                fvv - fyv + np.tanh(fzv),
                "float32",
            ),  # fused with a dimshuffle
            # Cases where the same input is reused many times.
            (
                mul(fx, fx, fx, fx),
                (fx,),
                (fxv,),
                1,
                fxv * fxv * fxv * fxv,
                "float32",
            ),
            (
                mul(fx, ftanx, ftanx),
                (fx,),
                (fxv,),
                1,
                fxv * np.tan(fxv) * np.tan(fxv),
                "float32",
            ),
            (
                mul(fx, ftanx, ftanx, fx),
                (fx,),
                (fxv,),
                1,
                fxv * np.tan(fxv) * np.tan(fxv) * fxv,
                "float32",
            ),
            (
                mul(ftanx, ftanx, fx + fy),
                (fx, fy),
                (fxv, fyv),
                1,
                np.tan(fxv) * np.tan(fxv) * (fxv + fyv),
                "float32",
            ),  # 70
            # Cases with different broadcast pattern. They should not
            # be merged as this would duplicate computation
            # The graph should have 2 elemwise and 1 dimshuffle
            (
                fx * sin(fs),
                (fx, fs),
                (fxv, fsv),
                3,
                fxv * np.sin(fsv),
                "float32",
            ),
        ],
    )
    def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
        """Verify that `Elemwise` fusion works."""
        g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
        if isinstance(out_dtype, dict):
            # The expected dtype depends on the active cast policy.
            out_dtype = out_dtype[config.cast_policy]
        if self._shared is None:
            f = function(list(sym_inputs), g, mode=self.mode)
            for x in range(nb_repeat):
                out = f(*val_inputs)
        else:
            # Exercise the fusion through a shared-variable update instead.
            out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
            assert out.dtype == g.dtype
            f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
            for x in range(nb_repeat):
                f(*val_inputs)
            out = out.get_value()
        atol = 1e-8
        if out_dtype == "float32":
            atol = 1e-6
        assert np.allclose(out, answer * nb_repeat, atol=atol)
        topo = f.maker.fgraph.toposort()
        topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
        if assert_len_topo:
            assert len(topo_) == nb_elemwise
            if nb_elemwise == 1:
                # if no variable appears multiple times in the
                # input of g,
                # check that the number of input to the Composite
                # Elemwise is ok
                if len(set(g.owner.inputs)) == len(g.owner.inputs):
                    expected_len_sym_inputs = sum(
                        not isinstance(x, Constant) for x in topo_[0].inputs
                    )
                    assert expected_len_sym_inputs == len(sym_inputs)
        assert out_dtype == out.dtype
    def test_fusion_35_inputs(self):
        r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
        inpts = vectors(["i%i" % i for i in range(35)])
        # Make an elemwise graph looking like:
        # sin(i34 + sin(i33 + sin(... i1 + sin(i0) ...)))
        out = sin(inpts[0])
        for idx in range(1, 35):
            out = sin(inpts[idx] + out)
        with config.change_flags(cxx=""):
            f = function(inpts, out, mode=self.mode)
        # Make sure they all weren't fused
        composite_nodes = [
            node
            for node in f.maker.fgraph.toposort()
            if isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite)
        ]
        assert not any(len(node.inputs) > 31 for node in composite_nodes)
    @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
    def test_big_fusion(self):
        # In the past, pickle of Composite generated in that case
        # crashed with max recursion limit. So we were not able to
        # generate C code in that case.
        factors = []
        sd = dscalar()
        means = dvector()
        cst_05 = at.constant(0.5)
        cst_m05 = at.constant(-0.5)
        cst_2 = at.constant(2)
        cst_m2 = at.constant(-2)
        ones = at.constant(np.ones(10))
        n = 85
        if config.mode in ["DebugMode", "DEBUG_MODE"]:
            n = 10
        for i in range(n):
            f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
                cst_05 * (sd**cst_m2) / np.pi
            )
            factors.append(at_sum(f))
        logp = add(*factors)
        vars = [sd, means]
        # Make sure that C compilation is used
        mode = Mode("cvm", self.rewrites)
        dlogp = function(vars, [aesara.grad(logp, v) for v in vars], mode=mode)
        # Make sure something was fused
        assert any(
            isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite)
            for node in dlogp.maker.fgraph.toposort()
        )
    def test_add_mul_fusion_inplace(self):
        # Fused add/mul chain after a `Dot` should end up as one inplace `Elemwise`.
        rewrites = RewriteDatabaseQuery(
            include=[
                "local_elemwise_fusion",
                "composite_elemwise_fusion",
                "canonicalize",
                "inplace",
            ],
            exclude=["cxx_only", "BlasOpt"],
        )
        mode = Mode(self.mode.linker, rewrites)
        x, y, z = dmatrices("xyz")
        out = dot(x, y) + x + y + z
        f = function([x, y, z], out, mode=mode)
        topo = [n for n in f.maker.fgraph.toposort()]
        assert len(topo) == 2
        assert topo[-1].op.inplace_pattern
        new_out = f.maker.fgraph.outputs[0]
        assert isinstance(new_out.owner.op, Elemwise)
        assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add)
        assert len(new_out.owner.inputs) == 4
        # TODO: Do we really need to do this?
        _ = f(
            np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
        )
    @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
    def test_no_c_code(self):
        r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
        # This custom `Op` has no `c_code` method
        class NoCCodeOp(aes.basic.UnaryScalarOp):
            def impl(self, x):
                return x * 2
        no_c_code_op = Elemwise(NoCCodeOp(aes.basic.upgrade_to_float))
        mode = Mode(linker="cvm")
        mode._optimizer = mode._optimizer.including(
            "local_elemwise_fusion",
            "composite_elemwise_fusion",
            "canonicalize",
            "inplace",
        )
        x = vector()
        out = x * no_c_code_op(x + 1)
        f = function([x], out, mode=mode)
        assert not any(
            isinstance(getattr(n.op, "scalar_op"), aes.basic.Composite)
            for n in f.maker.fgraph.toposort()
        )
    @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
    def test_test_values(self, test_value):
        """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
        The test values we're talking about are the ones used when C implementations
        are checked.
        """
        rewrites = RewriteDatabaseQuery(
            include=[
                "local_elemwise_fusion",
                "composite_elemwise_fusion",
                "canonicalize",
            ],
            exclude=["cxx_only", "BlasOpt"],
        )
        mode = Mode(self.mode.linker, rewrites)
        x, y, z = dmatrices("xyz")
        x.tag.test_value = test_value
        y.tag.test_value = test_value
        z.tag.test_value = test_value
        # An empty test value should make compilation fail under
        # `compute_test_value="raise"`.
        if test_value.size == 0:
            cm = pytest.raises(ValueError)
        else:
            cm = contextlib.suppress()
        with config.change_flags(
            compute_test_value="raise", compute_test_value_opt="raise"
        ):
            out = x * y + z
            with cm:
                f = function([x, y, z], out, mode=mode)
        if test_value.size != 0:
            # Confirm that the fusion happened
            assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
            assert len(f.maker.fgraph.toposort()) == 1
            x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
            assert np.array_equal(
                f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
            )
class TimesN(aes.basic.UnaryScalarOp):
    """Scalar `Op` that multiplies its input by a fixed integer `n`.

    Used in test TestCompositeCodegen
    Must be outside of the class, otherwise, the c cache code can't
    pickle this class and this cause stuff printing during test.
    """
    def __eq__(self, other):
        # `n` is part of the identity, in addition to the base-class fields.
        return super().__eq__(other) and self.n == other.n
    def __hash__(self):
        # Keep the hash consistent with `__eq__` above.
        return super().__hash__() ^ hash(self.n)
    def __init__(self, n, *args, **kwargs):
        # `n`: the constant multiplier baked into this op instance.
        self.n = n
        aes.basic.UnaryScalarOp.__init__(self, *args, **kwargs)
    def impl(self, x):
        # Python reference implementation.
        return x * self.n
    def c_support_code_apply(self, node, nodename):
        # Emit a per-apply C helper so that multiple `TimesN` instances with
        # different `n` values don't collide in the generated support code.
        n = str(self.n)
        return (
            """
        float %(nodename)s_timesn(float x) { return x * %(n)s; }
        """
            % locals()
        )
    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return f"{z} = {name}_timesn({x});"
class TestCompositeCodegen:
    """
    Test The Composite Ops code generation in a case where there is multiple
    scalar ops with support code.
    """
    def setup_method(self):
        # Two distinct `TimesN` ops, each with its own C support code.
        upgrade_to_float = aes.basic.upgrade_to_float
        self.scal_times_2 = TimesN(2, upgrade_to_float, name="times_2")
        self.times_2 = Elemwise(self.scal_times_2, name="times_2")
        self.scal_times_3 = TimesN(3, upgrade_to_float, name="times_3")
        self.times_3 = Elemwise(self.scal_times_3, name="times_3")
        self.x = fvector()
    def test_nested_composite(self):
        # times_3(times_2(x)) should fuse into a single Composite node.
        y = self.times_2(self.x)
        z = self.times_3(y)
        f = function([self.x], z)
        if config.mode != "FAST_COMPILE":
            assert len(f.maker.fgraph.toposort()) == 1
        fval = f([1, 2, 3])
        assert np.all(fval == [6, 12, 18])
    def test_local_useless_composite(self):
        # A multi-output Composite of which only one output is used should be
        # trimmed down to a single-output node.
        x = aes.float32()
        c = aes.Composite([x], [x + 1, x - 1])
        X = matrix()
        o = Elemwise(scalar_op=c)(X)
        mode = get_default_mode().including("local_useless_composite")
        f = function([X], o[0], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert len(topo[0].outputs) == 1
        utt.assert_allclose(f([[1.0]]), [[2.0]])
        f = function([X], o[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert len(topo[0].outputs) == 1
        utt.assert_allclose(f([[1.0]]), [[0.0]])
def test_local_useless_dimshuffle_makevector():
    """A ``dimshuffle(())`` of a one-element ``MakeVector`` should collapse back to the scalar input."""
    inp = scalar()
    vec = MakeVector(config.floatX)(inp)
    shuffled = vec.dimshuffle(())
    fgraph = FunctionGraph(outputs=[shuffled], copy_inputs=False)
    rewritten_fgraph = rewrite_graph(
        fgraph,
        clone=False,
        include=["canonicalize", "local_useless_dimshuffle_makevector"],
    )
    assert rewritten_fgraph.outputs[0] == inp
import numpy as np
import pytest
import aesara.scalar as aes
from aesara.compile.function import function
from aesara.compile.mode import OPT_NONE, Mode, get_default_mode
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.basic import Alloc, alloc, as_tensor_variable, second
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from aesara.tensor.type import dscalar
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_scalar(return_index, return_counts, return_inverse):
    """`unique` of a scalar should reduce to a plain `DimShuffle` of the input."""
    inp = dscalar()
    out = unique(
        inp,
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=None,
    )
    fgraph = FunctionGraph(outputs=[out], copy_inputs=False)
    rewritten_fgraph = rewrite_graph(
        fgraph, clone=False, include=["canonicalize", "local_Unique_scalar"]
    )
    rewritten = rewritten_fgraph.outputs[0]
    assert isinstance(rewritten.owner.op, DimShuffle)
    assert rewritten.owner.inputs[0] == inp
    # Numerically compare the rewritten graph with the un-rewritten reference.
    reference_mode = get_default_mode().excluding("local_Unique_scalar")
    fn = function([inp], [out, rewritten], mode=reference_mode)
    expected_val, rewritten_val = fn(np.array(-10.0, dtype=np.float64))
    assert np.array_equal(expected_val, rewritten_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Alloc_lift(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """`unique(alloc(x, ...))` should be rewritten to `unique(x)`, dropping the `Alloc`."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        alloc(x, *new_shape),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Alloc_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten
    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Alloc) for node in y_rewritten_fg.apply_nodes)
    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    # The remaining exclusions simply allow us to perform the check below that
    # makes sure the original `Alloc` is present in our reference (sub)graph.
    rewrite_mode = default_mode.excluding(
        "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
    )
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
    # Make sure that the original `Alloc` is used to compute the reference `y`
    # result
    assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """`unique(broadcast_to(x, ...))` should be rewritten to `unique(x)`."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        BroadcastTo()(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten
    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(
        isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
    )
    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
    # Make sure that the original `BroadcastTo` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
    )
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, repeats, repeat_axis",
    [
        (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Repeat(
    x_val,
    unique_axis,
    repeats,
    repeat_axis,
    return_index,
    return_counts,
    return_inverse,
):
    """`unique(repeat(x, ...))` should be rewritten to `unique(x)`."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        repeat(x, tuple(repeats), axis=repeat_axis),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Repeat_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten
    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Repeat) for node in y_rewritten_fg.apply_nodes)
    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_Repeat_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
    # Make sure that the original `Repeat` is used to compute the
    # reference `y` result
    assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_second(
    x_val, unique_axis, new_shape, return_index, return_counts, return_inverse
):
    """`unique(second(a, x))` should be rewritten to `unique(x)`, dropping the `Second` `Elemwise`."""
    x = as_tensor_variable(x_val).type()
    a = np.zeros(tuple(new_shape), dtype=x.dtype)
    y = unique(
        second(a, x),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_second_lift"],
        exclude=["local_Unique_scalar", "topo_constant_folding"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten
    assert isinstance(y_rewritten_start.owner.op, Unique)
    # A `DimShuffle` may remain between `Unique` and `x`; step over it.
    y_rewritten_start = y_rewritten_start.owner.inputs[0]
    if y_rewritten_start.owner and isinstance(y_rewritten_start.owner.op, DimShuffle):
        y_rewritten_start = y_rewritten_start.owner.inputs[0]
    assert y_rewritten_start == x
    assert not any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_rewritten_fg.apply_nodes
        if isinstance(node.op, Elemwise)
    )
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    y_fn = function([x], [y, y_rewritten], mode=Mode(optimizer=OPT_NONE))
    # Make sure that the original `Second` `Elemwise` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_fn.maker.fgraph.apply_nodes
        if isinstance(node.op, Elemwise)
    )
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
def test_local_remove_scalar_BroadcastTo():
    """`local_remove_scalar_BroadcastTo` should strip a no-op scalar broadcast."""
    scalar_var = dscalar()
    broadcasted = BroadcastTo()(scalar_var, ())
    assert isinstance(broadcasted.owner.op, BroadcastTo)
    rewritten = rewrite_graph(
        broadcasted,
        clone=False,
        include=["canonicalize", "local_remove_scalar_BroadcastTo"],
    )
    # Broadcasting a scalar to shape `()` is the identity, so the rewrite
    # must return the original variable itself.
    assert rewritten is scalar_var
......@@ -79,7 +79,7 @@ from aesara.tensor.math import round as at_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.basic import local_dimshuffle_lift
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
from aesara.tensor.rewriting.math import (
compute_mul,
is_1pexp,
......
import copy
import numpy as np
import pytest
import aesara.tensor as at
from aesara import shared
from aesara.compile.function import function
from aesara.compile.mode import get_default_mode, get_mode
from aesara.compile.ops import deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.type import Type
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import add, exp, maximum
from aesara.tensor.rewriting.basic import register_specialize
from aesara.tensor.rewriting.shape import (
ShapeFeature,
local_reshape_to_dimshuffle,
local_useless_reshape,
)
from aesara.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
reshape,
shape,
specify_shape,
)
from aesara.tensor.subtensor import set_subtensor
from aesara.tensor.type import (
fmatrix,
iscalar,
lscalar,
matrix,
scalar,
tensor,
tensor3,
tensor4,
vector,
)
from tests import unittest_tools as utt
# Use FAST_RUN-equivalent rewriting even when the configured mode is
# FAST_COMPILE, so the tests below exercise the full rewrite set.
rewrite_mode = get_mode("FAST_RUN" if config.mode == "FAST_COMPILE" else config.mode)
class TestShapeRewriter:
    """Tests for the shape-propagation rewrites (`ShapeFeature`/`ShapeOpt`)."""
    def test_basic(self):
        # The shape of `v + m` can be computed without evaluating the `add`,
        # so no `add` node should survive in the compiled graph.
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"
        v = vector()
        m = matrix()
        f = function([v, m], (v + m).shape, mode=mode)
        for node in f.maker.fgraph.toposort():
            assert node.op != add
    def test_constant(self):
        # A dimension added by `dimshuffle("x", ...)` has a statically known
        # length of 1, so the graph should fold to a single copy node.
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"
        v = vector()
        f = function([v], v.dimshuffle("x", "x", 0).shape[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert topo[0].op == deep_copy_op
    @staticmethod
    def max_pool_c01b(c01b, pool_shp, pool_stride, img_shp):
        """
        Like max_pool but with input using axes ('c', 0, 1, 'b')
        (Alex Krizhevsky format)

        pool_shp, pool_stride and img_shp are int that represent
        the same shp in x and y.
        """
        mx = None
        # Compute index in pooled space of last needed pool
        # (needed = each input pixel must appear in at least one pool)
        def last_pool(im_shp, p_shp, p_strd):
            rval = int(np.ceil(float(im_shp - p_shp) / p_strd))
            assert p_strd * rval + p_shp >= im_shp
            assert p_strd * (rval - 1) + p_shp < im_shp
            return rval
        # Compute starting row of the last pool
        last_pool_r = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        # Compute number of rows needed in img for all indexes to work out
        required_r = last_pool_r + pool_shp
        last_pool_c = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        required_c = last_pool_c + pool_shp
        # Pad the image with -inf so that every pool window is fully defined;
        # the padding can never win the `maximum` below.
        wide_infinity = at.alloc(
            -np.inf, c01b.shape[0], required_r, required_c, c01b.shape[3]
        )
        c01b = set_subtensor(wide_infinity[:, 0:img_shp, 0:img_shp, :], c01b)
        for row_within_pool in range(pool_shp):
            row_stop = last_pool_r + row_within_pool + 1
            for col_within_pool in range(pool_shp):
                col_stop = last_pool_c + col_within_pool + 1
                cur = c01b[
                    :,
                    row_within_pool:row_stop:pool_stride,
                    col_within_pool:col_stop:pool_stride,
                    :,
                ]
                if mx is None:
                    mx = cur
                else:
                    mx = maximum(mx, cur)
        return mx
    def test_broadcasted_dims(self):
        # This test a case that caused a crash during rewriting
        shp = (1, 1, 1, 1)
        rng = np.random.default_rng(utt.fetch_seed())
        a = shared(rng.random(shp).astype(config.floatX))
        out = self.max_pool_c01b(a, 1, 1, 1)
        # max_pool_c01b use -inf and this will trigger DebugMode error.
        mode = copy.copy(get_default_mode())
        mode.check_isfinite = False
        f = function([], out, mode=mode)
        f()
    def test_constant_merge(self):
        # This test the error in gh-1122 that is a caused by the
        # combination of merge rewriter and ShapeFeature.
        x = at.constant([0, 0])
        y = x[1:]
        x1 = x - at.join(0, y, y)
        x1.eval()
    def test_local_track_shape_i(self):
        class IdentityNoShape(Op):
            """Op that does not infer the output shape from the input one"""
            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])
            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()
            # def infer_shape(self, fgraph, node, (xshp,)):
            # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])]
        identity_noshape = IdentityNoShape()
        class IdentityShape(Op):
            """Op that does infer the output shape from the input one"""
            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])
            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()
            def infer_shape(self, fgraph, node, xshp_):
                # Could also just return.
                (xshp,) = xshp_
                return (xshp,)
        identity_shape = IdentityShape()
        @node_rewriter([IdentityNoShape])
        def local_identity_noshape_to_identity_shape(fgraph, node):
            """Transform the first `Op` into the second."""
            if isinstance(node.op, IdentityNoShape):
                return [identity_shape(node.inputs[0])]
        mode = get_default_mode().including("ShapeOpt", "specialize")
        rng = np.random.default_rng(utt.fetch_seed())
        x = tensor3("x")
        ins_x = identity_noshape(x)
        # Without the rewrite
        f = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((3, 4, 7)).astype(config.floatX)
        assert np.all(f(xval) == [3, 4, 7])
        f_ops = [node.op for node in f.maker.fgraph.toposort()]
        assert len(f_ops) == 5
        assert identity_noshape in f_ops
        assert identity_shape not in f_ops
        # Register the rewrite
        register_specialize(local_identity_noshape_to_identity_shape)
        mode = get_default_mode().including("ShapeOpt", "specialize")
        # The `identity_shape` `Op` should not be needed anymore to compute
        # the shape
        g = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(g(xval) == [6, 1, 2])
        g_ops = [node.op for node in g.maker.fgraph.toposort()]
        assert len(g_ops) == 4
        assert identity_noshape not in g_ops
        assert identity_shape not in g_ops
        # Test multiple applications of an `Op` without an `Op.infer_shape`
        ins_x3 = identity_noshape(identity_noshape(identity_noshape(x)))
        h = function([x], ins_x3.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(h(xval) == [6, 1, 2])
        h_ops = [node.op for node in h.maker.fgraph.toposort()]
        assert len(h_ops) == 4
        assert identity_noshape not in h_ops
        assert identity_shape not in h_ops
    def test_no_shapeopt(self):
        """Test that a basic example works even when `ShapeOpt` is excluded."""
        X = matrix()
        expr = X.shape[0]
        mode = get_default_mode().excluding("ShapeOpt")
        f = function([X], expr, mode=mode)
        # FIXME: This is not a good test.
        f([[1, 2], [2, 3]])
class TestReshape:
    """Tests that chained `Reshape` ops collapse to a single one."""

    def setup_method(self):
        self.mode = rewrite_mode
        self.op = Reshape

    def test_local_reshape(self):
        x = fmatrix()
        reshaped_once = self.op(3)(x, [2, 3, 4])
        reshaped_twice = self.op(1)(reshaped_once, [24])
        fn = function([x], reshaped_twice, mode=self.mode)
        nodes = fn.maker.fgraph.toposort()
        # Both reshapes should have been merged into a single node.
        n_reshapes = sum(isinstance(node.op, self.op) for node in nodes)
        assert n_reshapes == 1
        # The surviving node must still carry the original stack trace.
        assert check_stack_trace(fn, ops_to_check=[self.op])
class TestLocalUselessReshape:
    """Checks that reshapes to an array's own shape are removed."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    @staticmethod
    def _assert_reshape_free(fn):
        """Assert that no `Reshape` node survives in a compiled function."""
        assert not any(
            isinstance(node.op, Reshape) for node in fn.maker.fgraph.toposort()
        )

    def test_0(self):
        mode = get_default_mode().including("local_useless_reshape")
        i = iscalar("i")
        m = at.mgrid[
            0:i,
        ]
        self._assert_reshape_free(function([i], m, mode=mode))

    def test_1(self):
        x = matrix("x")
        r = x.reshape(x.shape)
        with_rewrite = get_default_mode().including("local_useless_reshape")
        self._assert_reshape_free(function([x], r, mode=with_rewrite))
        # The rewrite must also fire without the help of `ShapeOpt`.
        self._assert_reshape_free(
            function([x], r, mode=with_rewrite.excluding("ShapeOpt"))
        )
        # We do not need tests checking that stack traces are copied over,
        # because local_useless_reshape only removes nodes from the graph

    def test_2(self):
        x = matrix("x")
        r = x.reshape([Shape_i(i)(x) for i in range(x.ndim)])
        with_rewrite = get_default_mode().including("local_useless_reshape")
        self._assert_reshape_free(function([x], r, mode=with_rewrite))
        self._assert_reshape_free(
            function([x], r, mode=with_rewrite.excluding("ShapeOpt"))
        )

    def test_m1(self):
        x = matrix("x")
        r = x.reshape((x.shape[0], -1))
        with_rewrite = get_default_mode().including("local_useless_reshape")
        self._assert_reshape_free(function([x], r, mode=with_rewrite))
        self._assert_reshape_free(
            function([x], r, mode=with_rewrite.excluding("ShapeOpt"))
        )
class TestLocalReshapeToDimshuffle:
    """Tests rewriting `Reshape`s that only insert broadcastable dimensions."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    def test_1(self):
        lift_reshapes = out2in(local_reshape_to_dimshuffle)
        drop_useless = out2in(local_useless_reshape)
        x = shared(self.rng.standard_normal((4,)))
        y = shared(self.rng.standard_normal((5, 6)))
        fgraph = FunctionGraph(
            [x, y], [reshape(x, (1, 4)), reshape(y, (1, 5, 1, 6, 1, 1))]
        )
        assert str(fgraph) == (
            "FunctionGraph(Reshape{2}"
            "(<TensorType(float64, (None,))>, "
            "TensorConstant{[1 4]}), "
            "Reshape{6}"
            "(<TensorType(float64, (None, None))>, "
            "TensorConstant{[1 5 1 6 1 1]}))"
        )
        lift_reshapes.rewrite(fgraph)
        drop_useless.rewrite(fgraph)
        # The length-1 dimensions are now added by `DimShuffle`s; a single
        # `Reshape` remains only where the non-broadcastable shape changes.
        assert str(fgraph) == (
            "FunctionGraph(InplaceDimShuffle{x,0}"
            "(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,0,x,1,x,x}"
            "(Reshape{2}(<TensorType(float64, (None, None))>, "
            "TensorConstant{[5 6]})))"
        )
        # Check stacktrace was copied over correctly after the rewrite was applied
        assert check_stack_trace(fgraph, ops_to_check=(DimShuffle, Reshape))
def test_local_reshape_lift():
    """`local_reshape_lift` should move the `Reshape` below the `Elemwise`."""
    x = tensor4()
    out = exp(x).reshape([x.size])
    assert out.ndim == 1
    mode = get_default_mode().including("local_reshape_lift")
    fn = function([x], out, mode=mode)
    fn(np.random.random((5, 4, 3, 2)).astype(config.floatX))
    nodes = fn.maker.fgraph.toposort()
    # After the lift, the `Reshape` feeds the `Elemwise` rather than the
    # other way around.
    assert isinstance(nodes[-2].op, Reshape)
    assert isinstance(nodes[-1].op, Elemwise)
    assert check_stack_trace(fn, ops_to_check="last")
class TestShapeI(utt.InferShapeTester):
    """Tests for `Shape_i` evaluation and shape inference."""

    def setup_method(self):
        super().setup_method()

    def test_perform(self):
        rng = np.random.default_rng(utt.fetch_seed())
        vec = vector()
        vec_val = rng.random((3)).astype(config.floatX)
        fn = function([vec], Shape_i(0)(vec))
        utt.assert_allclose(fn(vec_val), vec_val.shape[0])
        mat = matrix()
        mat_val = rng.random((4, 3)).astype(config.floatX)
        # Check both axes of the matrix.
        for axis in range(2):
            fn = function([mat], Shape_i(axis)(mat))
            utt.assert_allclose(fn(mat_val), mat_val.shape[axis])

    def test_infer_shape(self):
        mat = matrix()
        mat_val = np.random.random((3, 4)).astype(config.floatX)
        for axis in (0, 1):
            self._compile_and_check([mat], [Shape_i(axis)(mat)], [mat_val], Shape_i)
class TestSameShape:
    """Tests for `ShapeFeature.same_shape`."""

    @staticmethod
    def _feature_for(inputs, output):
        """Build a `FunctionGraph` over *output* and attach a fresh `ShapeFeature`."""
        fgraph = FunctionGraph(inputs, [output], clone=False)
        feature = ShapeFeature()
        fgraph.attach_feature(feature)
        return feature

    def test_scalar(self):
        x = scalar()
        out = x + at.constant(1)
        feature = self._feature_for([x], out)
        assert feature.same_shape(x, out)

    def test_vector(self):
        x = vector()
        out = x + at.constant(1)
        feature = self._feature_for([x], out)
        assert feature.same_shape(x, out)

    def test_no_static_shapes(self):
        x = vector()
        y = vector()
        out = x + y
        feature = self._feature_for([x, y], out)
        # With no static shape information we cannot assume `x` and `y` have
        # equal shapes: either could be `(1,)` and broadcast against `(n,)`,
        # so `x` and `out` need not match.
        assert not feature.same_shape(x, out)
        # The following case isn't implemented
        assert not feature.same_shape(y, out)

    @pytest.mark.parametrize(
        "y_dim_0",
        [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
    )
    def test_vector_dim(self, y_dim_0):
        x = at.tensor(dtype="floatX", shape=(2, None))
        y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
        out = x + y
        feature = self._feature_for([x, y], out)
        # Dimension 0 is statically known equal; dimension 1 is not.
        assert feature.same_shape(x, out, 0, 0)
        assert not feature.same_shape(x, out, 1, 1)

    def test_vector_dim_err(self):
        x = vector()
        y = vector()
        out = x + y
        feature = self._feature_for([x, y], out)
        # Out-of-range dimension indices must raise.
        with pytest.raises(IndexError):
            feature.same_shape(x, out, 1, 0)
        with pytest.raises(IndexError):
            feature.same_shape(x, out, 0, 1)
@pytest.mark.parametrize(
    "shape",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape(shape):
    """``specify_shape(x, s).shape`` should reduce to ``s`` itself."""
    x = vector()
    shape_out = specify_shape(x, shape).shape
    fgraph = FunctionGraph(outputs=[shape_out], clone=False)
    rewrite_graph(fgraph, clone=False)
    # The shape comes straight from `shape`; `x` is no longer referenced.
    assert x not in fgraph.variables
    assert shape in fgraph.variables
@pytest.mark.parametrize(
    "s1",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
    """A partially specified shape must still remove the `SpecifyShape`."""
    x = matrix()
    shape_out = specify_shape(x, (s1, None)).shape
    fgraph = FunctionGraph(outputs=[shape_out], clone=False)
    assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
    rewrite_graph(fgraph, clone=False)
    # `x` is still needed for the unspecified dimension, and the known
    # dimension comes from `s1`, but no `SpecifyShape` may survive.
    assert x in fgraph.variables
    assert s1 in fgraph.variables
    assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
def test_local_Shape_i_of_broadcastable():
    """`Shape_i` on a broadcastable dimension should fold to the constant 1."""
    x = tensor(np.float64, [False, True])
    s = Shape_i(1)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)
    # The broadcastable dimension's length is known at graph-construction
    # time, so `x` itself is no longer referenced.
    assert x not in fgraph.variables
    assert fgraph.outputs[0].data == 1
    # A test for a non-`TensorType`
    class MyType(Type):
        ndim = 1
        def filter(self, *args, **kwargs):
            raise NotImplementedError()
        def __eq__(self, other):
            # NOTE(review): `thingy` is never defined on `MyType`, so this
            # comparison would raise if two `MyType` instances were ever
            # compared -- presumably that never happens in this test.
            return isinstance(other, MyType) and other.thingy == self.thingy
    class MyVariable(Variable):
        pass
    x = MyVariable(MyType(), None, None)
    s = Shape_i(0)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)
    # Non-`TensorType` variables carry no static shape info, so the graph
    # must be left untouched.
    assert fgraph.outputs[0] == s
def test_Shape_i_canonicalize():
    """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension.

    In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)``
    and nothing else. The rewrites `local_shape_to_shape_i`,
    `local_subtensor_remove_broadcastable_index`, and
    `local_useless_dimshuffle_makevector` need to work together to accomplish
    this, and we confirm that here.
    """
    v = vector()
    indexed_shape = shape(v)[0]
    fg = FunctionGraph(
        outputs=[indexed_shape], copy_inputs=False, features=[ShapeFeature()]
    )
    rewritten_fg = rewrite_graph(
        fg,
        clone=False,
        include=[
            "canonicalize",
        ],
    )
    result = rewritten_fg.outputs[0]
    # Nothing but a bare `Shape_i(0)` applied to `v` should remain.
    assert isinstance(result.owner.op, Shape_i)
    assert result.owner.op.i == 0
    assert result.owner.inputs[0] == v
......@@ -18,9 +18,9 @@ from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.tensor import as_tensor_variable
from aesara.tensor.basic import second
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.exceptions import ShapeError
from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any
from aesara.tensor.rewriting.basic import ShapeError
from aesara.tensor.type import (
TensorType,
bmatrix,
......
......@@ -12,7 +12,7 @@ from aesara.misc.safe_asarray import _asarray
from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.shape import (
Reshape,
Shape_i,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论