提交 5f809cfe authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify rewrites by assuming Elemwise / Alloc shapes are correct

上级 2c4a3e7b
......@@ -23,7 +23,7 @@ Many stabilize and stabilization rewrites refuse to be applied when a variable h
"""
import logging
from typing import TYPE_CHECKING, Optional, Union
from typing import Union
import numpy as np
......@@ -65,21 +65,17 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.var import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter
if TYPE_CHECKING:
from pytensor.tensor.rewriting.shape import ShapeFeature
_logger = logging.getLogger("pytensor.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter())
......@@ -261,31 +257,16 @@ def local_scalar_tensor_scalar(fgraph, node):
def local_elemwise_alloc(fgraph, node):
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
`Alloc`\s are effectively a type of `Elemwise` operation
(e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
`Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
broadcasts).
In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
`Alloc`\s.
The rewrite essentially performs the following replacement:
``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
when ``y.shape`` for some input ``y`` (or the combined shapes of the
non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``
In it's current form, it also explicitly accounts for `DimShuffle`\s of
In its current form, it also explicitly accounts for `DimShuffle`\s of
`Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions.
"""
# Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1:
return False
if len(node.outputs) > 1:
return False
return None
def dimshuffled_alloc(i):
return (
......@@ -305,76 +286,40 @@ def local_elemwise_alloc(fgraph, node):
if len(alloc_idxs) == 0:
return False
# Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
# baseline for the dimensions.
ref_var_idx = None
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc`, so that all `Alloc`s can be rewritten.
if idx not in alloc_idxs:
ref_var_idx = idx
break
# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if ref_var_idx is None:
for idx, i in enumerate(node.inputs):
# XXX: This broadcastable comparison doesn't work
if (
i.type.broadcastable == node.outputs[0].type.broadcastable
) and idx in alloc_idxs:
ref_var_idx = idx
break
if not hasattr(fgraph, "shape_feature"):
return False
input_shapes = [
tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
for i in node.inputs
]
bcasted_shape = broadcast_shape(
*input_shapes,
arrays_are_shapes=True,
)
new_inputs = list(node.inputs)
for idx in alloc_idxs:
i = node.inputs[idx]
# Remove `Alloc`
# Remove simple `Alloc`
if isinstance(i.owner.op, Alloc):
new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
new_inp = i.owner.inputs[0]
# TODO FIXME: This shouldn't be handled here.
# `DimShuffle`s should be lifted through `Alloc`s
# by other, more general rewrites.
# Remove `Alloc` in `DimShuffle`
# Remove `Dimshuffle(Alloc)`
elif isinstance(i.owner.op, DimShuffle):
old_alloc = i.owner.inputs[0]
new_alloc = old_alloc.owner.inputs[0]
old_alloc_inp = old_alloc.owner.inputs[0]
missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim
if missing_ndims > 0:
# The `Alloc` added new dimensions to the left.
# We replace those cases with a `DimShuffle` here.
# Nested dimshuffles will be merged later by other rewrites.
old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims)
# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
if new_alloc.ndim != old_alloc.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
new_alloc = new_alloc.dimshuffle(
["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
)
new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
new_inp = i.owner.op(old_alloc_inp)
copy_stack_trace(i, new_alloc)
new_inputs[idx] = new_alloc
copy_stack_trace(i, new_inp)
new_inputs[idx] = new_inp
# If this assert is triggered, it means we are recreating an equivalent graph
# which would result in cyclical merge rewrites.
if all(new is old for new, old in zip(new_inputs, node.inputs)):
return
new_outs = node.op(*new_inputs, return_list=True)
ret = node.op(*new_inputs, return_list=True)
copy_stack_trace(node.outputs, ret)
return ret
if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
new_outs = [
alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs
]
copy_stack_trace(node.outputs, new_outs)
return new_outs
@register_canonicalize("shape_unsafe")
......@@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node):
# The newly created node c doesn't have 'clients' yet,
# so this iteration takes place with node.outputs[0]
# TODO: This should just be a WalkingGraphRewrite!
replacements = {node.outputs[0]: c}
for client, cl_idx in fgraph.clients[node.outputs[0]]:
if (
......@@ -438,9 +384,8 @@ def local_fill_to_alloc(fgraph, node):
with their dependencies on those tensors' shapes, and sometimes those
shapes can be computed without needing to compute the tensors themselves.
XXX: This rewrite can produce inconsistent results, so do *not* consider
making it a canonicalization until those inconsistencies are
resolved/justified.
Like `local_fill_sink`, this rewrite assumes non-broadcastable shapes are equivalent,
which could mask shape errors.
"""
shape_ref, values_ref = node.inputs
out_type = node.outputs[0].type
......@@ -448,13 +393,6 @@ def local_fill_to_alloc(fgraph, node):
if values_ref.type.broadcastable == out_type.broadcastable:
# The assumption here is that `values_ref` already has the same shape
# as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
# XXX FIXME TODO: The only way this can be determined is if one
# absolutely knows that the shapes of `shape_ref` and `values_ref` are
# equal.
# This is an old rewrite, and it's only a
# "specialization/stabilization", so we're going to leave it be for
# now.
return [values_ref]
if shape_ref.type.broadcastable == out_type.broadcastable:
......@@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node):
copy_stack_trace(node.outputs[0], o)
return [o]
# The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
# TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))
return
......@@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node):
return [element_sum]
@register_useless("local_remove_switch_const_cond")
@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
@register_specialize
@node_rewriter([Elemwise])
@register_useless("shape_unsafe")
@register_canonicalize("fast_compile", "shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([switch])
def local_useless_switch(fgraph, node):
"""
This rewrite makes the following changes in a graph:
at.switch(cond, left, right) ->
switch(cond, left, right) ->
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
if left is right -> left
and
at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
"""
if not isinstance(node.op.scalar_op, aes.Switch):
return False
shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None)
if shape_feature is None:
return False
left = node.inputs[1]
right = node.inputs[2]
cond_var = node.inputs[0]
cond = extract_constant(cond_var, only_process_constants=True)
out_bcast = node.outputs[0].type.broadcastable
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, (np.number, np.bool_)
......@@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node):
else:
out = correct_out
input_shapes = [
tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
for inp in node.inputs
]
out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)
out = alloc(out, *out_shape)
if out.type.broadcastable != out_bcast:
out = broadcast_arrays(out, *node.inputs)[0]
# Copy over stacktrace from selected output to new output
copy_stack_trace(node.outputs + correct_out, out)
......@@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node):
if left == right:
# Note: No need to copy over stacktrace, because the input node
# already has its own stacktrace
if cond.type.is_super(left.type):
if left.type.broadcastable == out_bcast:
return [left]
ret = fill(cond, left)
ret = broadcast_arrays(left, cond)[0]
# Copy over stacktrace from switch output and correct branch
copy_stack_trace(node.outputs + left, ret)
......
......@@ -1013,7 +1013,7 @@ class TestLocalUselessSwitch:
z = at.switch(1, x, y)
f = function([x, y], z, mode=self.mode)
start_var = f.maker.fgraph.outputs[0].owner.inputs[0]
start_var = f.maker.fgraph.outputs[0]
assert isinstance(start_var.owner.op, Elemwise)
assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast)
assert not any(node.op == at.switch for node in f.maker.fgraph.toposort())
......@@ -1698,45 +1698,50 @@ class TestLocalElemwiseAlloc:
)
@pytest.mark.parametrize(
"expr, x_shape, y_shape",
"expr, x_shape, y_shape, needs_alloc",
[
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)),
(lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2), True),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1), False),
(lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3), False),
(
lambda x, y: at.mul(
at.alloc(x, 3).dimshuffle("x", 0), y.dimshuffle("x", "x")
),
(),
(),
True,
),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), (), True),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1), False),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2), False),
(
lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)),
(15, 1),
(15, 1),
False,
),
(
lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)),
(15, 2),
(15, 2),
False,
),
(
lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y),
(15, 2),
(2, 15),
False,
),
(lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)),
(lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2), False),
(
lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y),
(15, 2),
(2, 15),
False,
),
],
)
def test_basic(self, expr, x_shape, y_shape):
def test_basic(self, expr, x_shape, y_shape, needs_alloc):
x = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)
......@@ -1752,10 +1757,16 @@ class TestLocalElemwiseAlloc:
on_unused_input="ignore",
)
assert not any(
isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort()
)
nodes = z_opt.maker.fgraph.toposort()
if needs_alloc:
# When the final result needs an Alloc, this should be the last node
# x = scalar; y = vector; mul(x, ones_like(y)) -> alloc(x, y.shape)
assert isinstance(nodes[-1].op, Alloc)
nodes = nodes[:-1]
assert not any(isinstance(node.op, Alloc) for node in nodes)
# Check results are the same without the optimization
z_no_opt = pytensor.function(
[x, y],
z,
......@@ -1799,7 +1810,7 @@ class TestLocalElemwiseAlloc:
[self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
self.verify_op_count(func, 1, SpecifyShape)
func = function(
[self.vec, self.mat],
......@@ -1807,7 +1818,7 @@ class TestLocalElemwiseAlloc:
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
self.verify_op_count(func, 1, SpecifyShape)
# No optimization on alloc without assert
func = function(
......@@ -1839,7 +1850,10 @@ class TestLocalElemwiseAlloc:
self.alloc_w_dep_broad2 + self.mat,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
# This graph requires one outer Alloc and an Assert
# To make sure `mat` is square since we end up doing
# broadcast_to(x, mat[..., None].shape) + mat[None, ...]
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 1, Assert)
def test_remove_alloc_w_dimshuffle(self):
......@@ -1851,16 +1865,13 @@ class TestLocalElemwiseAlloc:
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# TODO FIXME: The `BroadcastTo` shapes should use the constants
# provided by the first/`Alloc` term, and not the unknown values from
# the `tens` term.
func = function(
[self.vec, self.tens],
self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
self.verify_op_count(func, 1, SpecifyShape)
func = function(
[self.vec, self.tens],
......@@ -1888,16 +1899,13 @@ class TestLocalElemwiseAlloc:
self.verify_op_count(func, 2, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle with assert
# TODO: When we support static shape constraints like `shape[i] != 1`,
# reproduce this with such a constraint on `mat` and make sure the
# `BroadcastTo` is removed.
func = function(
[self.vec, self.mat],
self.tv_wo_dep + self.tm_wo_dep,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
# It still needs an outer alloc to broadcast final shape
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# No optimization on dimshuffle without assert
......@@ -1909,25 +1917,24 @@ class TestLocalElemwiseAlloc:
self.verify_op_count(func, 2, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle without assert
func = function(
[self.vec, self.mat, self.s],
self.tv_w_dep + self.tm_w_dep,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
# The second assert is from the shape check...
self.verify_op_count(func, 2, Assert)
# It still needs an outer alloc to broadcast final shape
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
def test_misc(self):
x = row(dtype=self.dtype)
y = tensor(dtype=self.dtype, shape=(None, None, 1))
x = row("x", dtype=self.dtype)
y = tensor("y", dtype=self.dtype, shape=(None, None, 1))
out = at.alloc(x, 5, 5).dimshuffle(0, 1, "x") + y
func = function([y, x], out, mode=self.fast_run_mode)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
self.verify_op_count(func, 1, SpecifyShape)
y_val = np.random.random((5, 5, 1)).astype(self.dtype)
x_val = np.random.random((1, 5)).astype(self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论