提交 65b96c1c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Canonicalize squeeze out of reshape and specialize back

上级 dbf5f38e
...@@ -36,6 +36,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -36,6 +36,7 @@ from pytensor.tensor.rewriting.basic import (
register_useless, register_useless,
topo_constant_folding, topo_constant_folding,
) )
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Reshape, Reshape,
Shape, Shape,
...@@ -757,40 +758,36 @@ pytensor.compile.mode.optdb.register( ...@@ -757,40 +758,36 @@ pytensor.compile.mode.optdb.register(
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
@register_useless
@register_canonicalize
@node_rewriter([Reshape])
def local_useless_expand_dims_in_reshape(fgraph, node):
    """Remove useless expand_dims `DimShuffle` operations inside Reshape.

        reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
        reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)

    Implicit (and useless) squeezes are kept in the graph, as they are
    part of the canonical form of the graph.
    """
    expanded_x, new_shape = node.inputs

    if not (
        expanded_x.owner is not None
        and isinstance(expanded_x.owner.op, DimShuffle)
        # `augment` is non-empty only when the DimShuffle adds broadcastable
        # ("x") dimensions, i.e. when it performs an expand_dims
        and expanded_x.owner.op.augment
    ):
        return False

    [x] = expanded_x.owner.inputs

    # Drop the "x" (expand_dims) entries from the dimshuffle order. Whatever
    # remains is a transposition of the original dimensions that must be kept
    # in the graph before the reshape.
    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
    if new_order != tuple(range(x.type.ndim)):
        x = x.dimshuffle(new_order)

    new_reshaped_x = x.reshape(new_shape)
    copy_stack_trace(node.outputs[0], new_reshaped_x)
    return [new_reshaped_x]
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
...@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node): ...@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node):
shape_feature = getattr(fgraph, "shape_feature", None) shape_feature = getattr(fgraph, "shape_feature", None)
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1 # Match case where at least (n-1) entries correspond to the original shape:
# or cases where all but one dimension are provably preserved # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]])
# Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no reshape.
output_shape_is = _unpack_shape_vector(output_shape) output_shape_is = _unpack_shape_vector(output_shape)
nb_m1 = 0 nb_m1 = 0
shape_match = [False] * inp.type.ndim shape_match = [False] * inp.type.ndim
for dim in range(inp.type.ndim): for dim in range(inp.type.ndim):
...@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node): ...@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node):
nb_m1 += 1 nb_m1 += 1
if nb_m1 <= 1 and all(shape_match): if nb_m1 <= 1 and all(shape_match):
return [inp] return [inp] # This is provably correct
# There is one missing match, but all other dimensions match # There is one missing match, but all other dimensions match
# Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y)
if (nb_m1 == 0) and (shape_match.count(False) == 1): if (nb_m1 == 0) and (shape_match.count(False) == 1):
return [inp] return [inp] # This could mask a shape error
return False return False
@register_canonicalize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
    r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.

    It's always valid to squeeze an input before doing the same reshape operation.
    Equivalently, it's always valid to remove `1` entries from the reshape shape
    and replace them by an expand_dims after the rewritten reshape operation.

    We chose to canonicalize the graph in this way as it allows isolating
    operations that are unique to the reshaping operation (mixing dimensions)
    from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims).
    This can allow further simplifications by other rewrites that target
    DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations.

    For example:
        - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
        - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
        - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))

    """
    inp, output_shape = node.inputs
    [output] = node.outputs

    # Remove any broadcastable dimensions from the input
    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]

    # Trivial case, all dimensions of input/output are known to be broadcastable:
    # there's nothing to reshape
    if all(inp.type.broadcastable) or all(output.type.broadcastable):
        new_output_shape = []
        expand_axes = tuple(range(output.type.ndim))
    else:
        unpacked_shape = _unpack_shape_vector(output_shape)
        new_output_shape = []
        expand_axes = []
        for i, dim_length in enumerate(unpacked_shape):
            if isinstance(dim_length, Constant) and (
                dim_length.data == 1
                # -1 can be an implicit expand_dims, but it's tricky to prove
                # as we would need to check whether all other dimensions
                # already explain the full size of the array.
                # Example: np.zeros((2, 2, 2)).reshape((8, -1))
                # We rely on the output static shape which will already have figured
                # it out for some (but not all) cases
                or (dim_length.data == -1 and output.type.shape[i] == 1)
            ):
                expand_axes.append(i)
            else:
                new_output_shape.append(dim_length)

    if squeeze_axes or expand_axes:
        new_out = inp.squeeze(squeeze_axes)

        if new_output_shape:
            new_out = new_out.reshape(new_output_shape)
            copy_stack_trace(output, new_out)

        new_out = expand_dims(new_out, expand_axes)

        if not new_output_shape:
            # Eagerly merge consecutive squeeze and expand_dims
            new_out = apply_local_dimshuffle_lift(fgraph, new_out)

        copy_stack_trace(output, new_out)
        return [new_out]
@register_specialize
@node_rewriter([Reshape])
def local_fuse_squeeze_reshape(fgraph, node):
    r"""Merge a squeeze that directly precedes a reshape into the reshape.

    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
    """
    squeezed_x, target_shape = node.inputs

    squeeze_node = squeezed_x.owner
    if squeeze_node is None:
        return None
    if not (isinstance(squeeze_node.op, DimShuffle) and squeeze_node.op.is_squeeze):
        return None

    # A reshape can always subsume a squeeze.
    [pre_squeeze_x] = squeeze_node.inputs
    return [pre_squeeze_x.reshape(target_shape)]
@register_specialize
@node_rewriter([DimShuffle])
def local_fuse_expand_dims_reshape(fgraph, node):
    r"""Merge an expand_dims that directly follows a reshape into the reshape.

    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
    """
    if not node.op.is_expand_dims:
        return None

    [reshaped_x] = node.inputs
    reshape_node = reshaped_x.owner
    if reshape_node is None or not isinstance(reshape_node.op, Reshape):
        return None

    if len(fgraph.clients[reshaped_x]) > 1:
        # The reshape is used elsewhere, don't fuse as it can sometimes require a copy.
        # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[: None] * y[None, :]`
        return None

    x, inner_shape = reshape_node.inputs

    # Splice a length-1 entry into the target shape for every expanded axis.
    fused_shape = list(_unpack_shape_vector(inner_shape))
    for axis in node.op.augment:
        fused_shape.insert(axis, 1)

    fused_reshape = x.reshape(fused_shape)
    copy_stack_trace(node.outputs[0], fused_reshape)
    return [fused_reshape]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([Reshape]) @node_rewriter([Reshape])
......
...@@ -332,7 +332,6 @@ class TestLocalCanonicalizeAlloc: ...@@ -332,7 +332,6 @@ class TestLocalCanonicalizeAlloc:
mode = rewrite_mode.including( mode = rewrite_mode.including(
"local_dimshuffle_lift", "local_dimshuffle_lift",
"local_useless_dimshuffle_in_reshape",
"local_alloc_sink_dimshuffle", "local_alloc_sink_dimshuffle",
) )
f = function([x], [y], mode=mode) f = function([x], [y], mode=mode)
......
...@@ -56,7 +56,10 @@ from pytensor.tensor.math import pow as pt_pow ...@@ -56,7 +56,10 @@ from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import round as pt_round from pytensor.tensor.math import round as pt_round
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape from pytensor.tensor.rewriting.shape import (
local_fuse_squeeze_reshape,
local_useless_expand_dims_in_reshape,
)
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
...@@ -182,7 +185,7 @@ class TestDimshuffleLift: ...@@ -182,7 +185,7 @@ class TestDimshuffleLift:
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner) assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
def test_local_useless_dimshuffle_in_reshape(): def test_local_useless_expand_dims_in_reshape():
vec = TensorType(dtype="float64", shape=(None,))("vector") vec = TensorType(dtype="float64", shape=(None,))("vector")
mat = TensorType(dtype="float64", shape=(None, None))("mat") mat = TensorType(dtype="float64", shape=(None, None))("mat")
row = TensorType(dtype="float64", shape=(1, None))("row") row = TensorType(dtype="float64", shape=(1, None))("row")
...@@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape():
clone=False, clone=False,
) )
assert len(g.apply_nodes) == 4 * 3 assert len(g.apply_nodes) == 4 * 3
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) useless_dimshuffle_in_reshape = out2in(
local_useless_expand_dims_in_reshape,
# Useless squeeze in reshape is not a canonicalization anymore
local_fuse_squeeze_reshape,
)
useless_dimshuffle_in_reshape.rewrite(g) useless_dimshuffle_in_reshape.rewrite(g)
assert equal_computations( assert equal_computations(
g.outputs, g.outputs,
...@@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape():
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
# Check that the rewrite does not get applied when the order # Check that the rewrite does not mess meaningful transpositions before the reshape
# of dimensions has changed.
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False) h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
assert len(h.apply_nodes) == 3 assert len(h.apply_nodes) == 3
useless_dimshuffle_in_reshape.rewrite(h) useless_dimshuffle_in_reshape.rewrite(h)
assert equal_computations( assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)])
h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
)
class TestFusion: class TestFusion:
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import shared from pytensor import shared
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.basic import Apply, Variable, equal_computations
...@@ -426,6 +426,60 @@ class TestLocalReshapeToDimshuffle: ...@@ -426,6 +426,60 @@ class TestLocalReshapeToDimshuffle:
assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape)) assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
def test_expand_dims(self):
    # Reshaping a scalar to (1, -1) is an implicit expand_dims: the
    # canonicalizer should turn the Reshape into a pure DimShuffle.
    scalar_x = pt.scalar()
    reshaped = scalar_x.reshape((1, -1))
    assert isinstance(reshaped.owner.op, Reshape)

    rewritten = rewrite_graph(reshaped, include=("canonicalize",))
    expected = pt.expand_dims(scalar_x, (0, 1))
    assert equal_computations([rewritten], [expected])
def test_squeeze_of_alloc(self):
    # This shows up in the graph of repeat: an Alloc adds a leading
    # broadcastable dimension that an implicit-squeeze Reshape then removes.
    x = pt.vector("x", shape=(9,))
    broadcasted = pt.alloc(x, 1, 12, x.shape[0])
    squeezing_reshape = broadcasted.reshape((12, x.shape[0]))

    rewritten = rewrite_graph(squeezing_reshape, include=("canonicalize", "ShapeOpt"))
    assert equal_computations([rewritten], [pt.alloc(x, 12, 9)], strict_dtype=False)
def test_expand_dims_squeeze_reshape_fusion():
    x = pt.tensor("x", shape=(1, 9))
    # Build squeeze -> reshape -> expand_dims and confirm that structure.
    chained = x.squeeze(0).reshape((3, 3))[..., None]
    assert isinstance(chained.owner.op, DimShuffle)
    inner_reshape = chained.owner.inputs[0]
    assert isinstance(inner_reshape.owner.op, Reshape)
    assert isinstance(inner_reshape.owner.inputs[0].owner.op, DimShuffle)

    fused = rewrite_graph(chained, include=("specialize",))

    # In this case we cannot get rid of the reshape, squeeze or expand_dims,
    # so we fuse them all in one reshape
    assert equal_computations([fused], [x.reshape((3, 3, 1))])
def test_implicit_broadcasting_via_repeat():
    x = pt.vector("x", shape=(3,), dtype=int)
    y = pt.vector("y", shape=(9,), dtype=int)
    out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1)

    # Each repeat introduces a Reshape in the graph.
    lhs, rhs = out.owner.inputs
    assert isinstance(lhs.owner.op, Reshape)
    assert isinstance(rhs.owner.op, Reshape)

    # After rewriting, the explicit repeats collapse into implicit broadcasting.
    rewritten = rewrite_graph(out, include=("canonicalize", "specialize"))
    assert equal_computations([rewritten], [x[None] <= y[:, None]])

    # Both graphs must agree numerically when evaluated without rewrites.
    no_rewrite_mode = Mode(linker="py", optimizer=None)
    x_val = np.arange(3) + 1
    y_val = np.arange(9)
    np.testing.assert_allclose(
        rewritten.eval({x: x_val, y: y_val}, mode=no_rewrite_mode),
        out.eval({x: x_val, y: y_val}, mode=no_rewrite_mode),
    )
def test_local_reshape_lift(): def test_local_reshape_lift():
x = tensor4() x = tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论