提交 65b96c1c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Canonicalize squeeze out of reshape and specialize back

上级 dbf5f38e
......@@ -36,6 +36,7 @@ from pytensor.tensor.rewriting.basic import (
register_useless,
topo_constant_folding,
)
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import (
Reshape,
Shape,
......@@ -757,40 +758,36 @@ pytensor.compile.mode.optdb.register(
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
@register_useless
@register_canonicalize
@node_rewriter([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node):
def local_useless_expand_dims_in_reshape(fgraph, node):
"""
Removes useless DimShuffle operation inside Reshape:
reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
reshape(col.dimshuffle(0), shp) => reshape(col, shp)
Removes useless expand_dims `DimShuffle` operations inside Reshape:
reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)
Implicit (and useless) squeezes are kept in the graph, as they are
part of the canonical form of the graph.
"""
dimshuffled_x, new_shape = node.inputs
expanded_x, new_shape = node.inputs
if not (
dimshuffled_x.owner is not None
and isinstance(dimshuffled_x.owner.op, DimShuffle)
expanded_x.owner is not None
and isinstance(expanded_x.owner.op, DimShuffle)
and expanded_x.owner.op.augment
):
return False
[inp] = dimshuffled_x.owner.inputs
new_order = dimshuffled_x.owner.op.new_order
new_order_of_nonbroadcast = []
for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
if s != 1:
new_order_of_nonbroadcast.append(i)
no_change_in_order = all(
new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
for i in range(len(new_order_of_nonbroadcast) - 1)
)
if no_change_in_order:
ret = inp.reshape(new_shape)
copy_stack_trace(node.outputs[0], ret)
return [ret]
[x] = expanded_x.owner.inputs
new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
if new_order != tuple(range(x.type.ndim)):
x = x.dimshuffle(new_order)
new_reshaped_x = x.reshape(new_shape)
copy_stack_trace(node.outputs[0], new_reshaped_x)
return [new_reshaped_x]
@register_canonicalize("shape_unsafe")
......@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node):
shape_feature = getattr(fgraph, "shape_feature", None)
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
# or cases where all but one dimension are provably preserved
# Match case where at least (n-1) entries correspond to the original shape:
# Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]])
# Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no-op reshape.
output_shape_is = _unpack_shape_vector(output_shape)
nb_m1 = 0
shape_match = [False] * inp.type.ndim
for dim in range(inp.type.ndim):
......@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node):
nb_m1 += 1
if nb_m1 <= 1 and all(shape_match):
return [inp]
return [inp] # This is provably correct
# There is one missing match, but all other dimensions match
# Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y)
if (nb_m1 == 0) and (shape_match.count(False) == 1):
return [inp]
return [inp] # This could mask a shape error
return False
@register_canonicalize
@register_canonicalize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s.
r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.
The goal is to avoid using `Reshape` to add or remove broadcastable
dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can
cancel out and/or be removed later on.
It's always valid to squeeze an input before doing the same reshape operation.
Equivalently, it's always valid to remove `1` entries from the reshape shape
and replace them by an expand_dims after the rewritten reshape operation.
We chose to canonicalize the graph in this way as it allows isolating
operations that are unique to the reshaping operation (mixing dimensions)
from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims).
This can allow further simplifications by other rewrites that target
DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations.
For example:
- reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,))
- reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
- reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
- reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
- reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))
"""
inp, output_shape = node.inputs
[output] = node.outputs
unpacked_shape = _unpack_shape_vector(output_shape)
expand_axes = []
new_output_shape = []
for i, dim in enumerate(unpacked_shape):
if isinstance(dim, Constant) and dim.data == 1:
expand_axes.append(i)
else:
new_output_shape.append(dim)
# Remove any broadcastable dimensions from the input
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
# Trivial case, all dimensions of input/output are known to be broadcastable:
# there's nothing to reshape
if all(inp.type.broadcastable) or all(output.type.broadcastable):
new_output_shape = []
expand_axes = tuple(range(output.type.ndim))
else:
unpacked_shape = _unpack_shape_vector(output_shape)
new_output_shape = []
expand_axes = []
for i, dim_length in enumerate(unpacked_shape):
if isinstance(dim_length, Constant) and (
dim_length.data == 1
# -1 can be an implicit expand_dims, but it's tricky to prove
# as we would need to check whether all other dimensions
# already explain the full size of the array.
# Example: np.zeros((2, 2, 2)).reshape((8, -1))
# We rely on the output static shape which will already have figured
# it out for some (but not all) cases
or (dim_length.data == -1 and output.type.shape[i] == 1)
):
expand_axes.append(i)
else:
new_output_shape.append(dim_length)
if squeeze_axes or expand_axes:
new_out = inp.squeeze(squeeze_axes)
if new_output_shape:
new_out = new_out.reshape(new_output_shape)
copy_stack_trace(output, new_out)
new_out = expand_dims(new_out, expand_axes)
if not new_output_shape:
# Eagerly merge consecutive squeeze and expand_dims
new_out = apply_local_dimshuffle_lift(fgraph, new_out)
if len(new_output_shape) != output.type.ndim:
inner = inp.reshape(new_output_shape)
copy_stack_trace(output, inner)
new_out = expand_dims(inner, expand_axes)
copy_stack_trace(output, new_out)
return [new_out]
@register_specialize
@node_rewriter([Reshape])
def local_fuse_squeeze_reshape(fgraph, node):
    r"""If there is a squeeze right before a reshape, merge them.

    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.

    For example:
        reshape(squeeze(x, axis), shp) -> reshape(x, shp)

    Returns a one-element list holding the fused reshape, or ``None`` when the
    reshaped input is not the output of a squeeze `DimShuffle`.
    """
    squeezed_x, new_shape = node.inputs

    squeeze_node = squeezed_x.owner
    if not (
        squeeze_node is not None
        and isinstance(squeeze_node.op, DimShuffle)
        and squeeze_node.op.is_squeeze
    ):
        return None

    # A reshape can always subsume a squeeze.
    x = squeeze_node.inputs[0]
    new_reshaped_x = x.reshape(new_shape)
    # Keep the stack trace of the replaced output, as the other rewrites in
    # this module do when they return a replacement variable.
    copy_stack_trace(node.outputs[0], new_reshaped_x)
    return [new_reshaped_x]
@register_specialize
@node_rewriter([DimShuffle])
def local_fuse_expand_dims_reshape(fgraph, node):
    r"""Fold an expand_dims that directly follows a reshape into that reshape.

    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.

    The rewrite only fires when the `DimShuffle` is a pure expand_dims, its
    input is the output of a `Reshape`, and that output has a single client.
    """
    if not node.op.is_expand_dims:
        return None

    reshape_out = node.inputs[0]
    reshape_node = reshape_out.owner
    if not (reshape_node and isinstance(reshape_node.op, Reshape)):
        return None

    if len(fgraph.clients[reshape_out]) > 1:
        # Another consumer still needs the un-expanded reshape; fusing would
        # duplicate the reshape, which can sometimes require a copy.
        # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[:, None] * y[None, :]`
        return None

    inner_x, inner_shape = reshape_node.inputs

    # Splice a length-1 entry into the target shape for every expanded axis.
    # NOTE(review): inserting in iteration order assumes `augment` is sorted
    # in increasing order - confirm against `DimShuffle`.
    fused_shape = list(_unpack_shape_vector(inner_shape))
    for axis in node.op.augment:
        fused_shape.insert(axis, 1)

    fused_out = inner_x.reshape(fused_shape)
    copy_stack_trace(node.outputs[0], fused_out)
    return [fused_out]
@register_canonicalize
@register_specialize
@node_rewriter([Reshape])
......
......@@ -332,7 +332,6 @@ class TestLocalCanonicalizeAlloc:
mode = rewrite_mode.including(
"local_dimshuffle_lift",
"local_useless_dimshuffle_in_reshape",
"local_alloc_sink_dimshuffle",
)
f = function([x], [y], mode=mode)
......
......@@ -56,7 +56,10 @@ from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import round as pt_round
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
from pytensor.tensor.rewriting.shape import (
local_fuse_squeeze_reshape,
local_useless_expand_dims_in_reshape,
)
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import (
TensorType,
......@@ -182,7 +185,7 @@ class TestDimshuffleLift:
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
def test_local_useless_dimshuffle_in_reshape():
def test_local_useless_expand_dims_in_reshape():
vec = TensorType(dtype="float64", shape=(None,))("vector")
mat = TensorType(dtype="float64", shape=(None, None))("mat")
row = TensorType(dtype="float64", shape=(1, None))("row")
......@@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape():
clone=False,
)
assert len(g.apply_nodes) == 4 * 3
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape = out2in(
local_useless_expand_dims_in_reshape,
# Useless squeeze in reshape is not a canonicalization anymore
local_fuse_squeeze_reshape,
)
useless_dimshuffle_in_reshape.rewrite(g)
assert equal_computations(
g.outputs,
......@@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape():
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
# Check that the rewrite does not get applied when the order
# of dimensions has changed.
# Check that the rewrite does not mess up meaningful transpositions before the reshape
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
assert len(h.apply_nodes) == 3
useless_dimshuffle_in_reshape.rewrite(h)
assert equal_computations(
h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
)
assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)])
class TestFusion:
......
......@@ -6,7 +6,7 @@ import pytest
import pytensor.tensor as pt
from pytensor import shared
from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations
......@@ -426,6 +426,60 @@ class TestLocalReshapeToDimshuffle:
assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
def test_expand_dims(self):
    """A reshape that only adds length-1 dims canonicalizes to expand_dims."""
    scalar_x = pt.scalar()

    # Reshaping a scalar to (1, -1) is an implicit expand_dims
    reshaped = scalar_x.reshape((1, -1))
    assert isinstance(reshaped.owner.op, Reshape)

    canonical = rewrite_graph(reshaped, include=("canonicalize",))
    expected = pt.expand_dims(scalar_x, (0, 1))
    assert equal_computations([canonical], [expected])
def test_squeeze_of_alloc(self):
    """A reshape that only drops a length-1 dim of an alloc simplifies to an alloc."""
    # This pattern shows up in the graph of repeat
    vec = pt.vector("x", shape=(9,))
    vec_len = vec.shape[0]
    allocated = pt.alloc(vec, 1, 12, vec_len)

    # The leading length-1 dimension is dropped implicitly by the reshape
    squeezed = allocated.reshape((12, vec.shape[0]))

    rewritten = rewrite_graph(squeezed, include=("canonicalize", "ShapeOpt"))
    expected = pt.alloc(vec, 12, 9)
    assert equal_computations([rewritten], [expected], strict_dtype=False)
def test_expand_dims_squeeze_reshape_fusion():
    """Squeeze -> Reshape -> expand_dims is fused into one Reshape by specialize."""
    x = pt.tensor("x", shape=(1, 9))
    graph_out = x.squeeze(0).reshape((3, 3))[..., None]

    # Walk the unrewritten graph from the output inwards:
    # DimShuffle (expand_dims) <- Reshape <- DimShuffle (squeeze)
    outer_node = graph_out.owner
    assert isinstance(outer_node.op, DimShuffle)
    middle_node = outer_node.inputs[0].owner
    assert isinstance(middle_node.op, Reshape)
    inner_node = middle_node.inputs[0].owner
    assert isinstance(inner_node.op, DimShuffle)

    specialized = rewrite_graph(graph_out, include=("specialize",))

    # None of the reshape, squeeze or expand_dims can be removed here,
    # so all three are expected to be merged into a single reshape
    expected = x.reshape((3, 3, 1))
    assert equal_computations([specialized], [expected])
def test_implicit_broadcasting_via_repeat():
    """Repeats that only broadcast are rewritten into plain expand_dims operands."""
    x = pt.vector("x", shape=(3,), dtype=int)
    y = pt.vector("y", shape=(9,), dtype=int)
    repeated_x = x[None, :].repeat(9, axis=0)
    repeated_y = y[:, None].repeat(3, axis=1)
    out = repeated_x <= repeated_y

    # Both operands of the comparison come out of a Reshape
    assert isinstance(out.owner.inputs[0].owner.op, Reshape)
    assert isinstance(out.owner.inputs[1].owner.op, Reshape)

    new_out = rewrite_graph(out, include=("canonicalize", "specialize"))
    assert equal_computations([new_out], [x[None] <= y[:, None]])

    # Evaluate both graphs without any further rewrites and compare the values
    no_rewrite_mode = Mode(linker="py", optimizer=None)
    x_test = np.arange(3) + 1
    y_test = np.arange(9)

    def evaluate(graph):
        return graph.eval({x: x_test, y: y_test}, mode=no_rewrite_mode)

    np.testing.assert_allclose(evaluate(new_out), evaluate(out))
def test_local_reshape_lift():
x = tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论