Commit 65b96c1c, authored by Ricardo Vieira, committed by Ricardo Vieira

Canonicalize squeeze out of reshape and specialize back

Parent dbf5f38e
@@ -36,6 +36,7 @@ from pytensor.tensor.rewriting.basic import (
     register_useless,
     topo_constant_folding,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import (
     Reshape,
     Shape,
@@ -757,40 +758,36 @@ pytensor.compile.mode.optdb.register(
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
 
 
+@register_useless
 @register_canonicalize
 @node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
+def local_useless_expand_dims_in_reshape(fgraph, node):
     """
-    Removes useless DimShuffle operation inside Reshape:
+    Removes useless expand_dims `DimShuffle` operations inside Reshape:
 
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+      reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
+      reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)
 
+    Implicit (and useless) squeezes are kept in the graph, as they are
+    part of the canonical form of the graph.
     """
-    dimshuffled_x, new_shape = node.inputs
+    expanded_x, new_shape = node.inputs
 
     if not (
-        dimshuffled_x.owner is not None
-        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+        expanded_x.owner is not None
+        and isinstance(expanded_x.owner.op, DimShuffle)
+        and expanded_x.owner.op.augment
     ):
         return False
 
-    [inp] = dimshuffled_x.owner.inputs
-    new_order = dimshuffled_x.owner.op.new_order
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        ret = inp.reshape(new_shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
+    [x] = expanded_x.owner.inputs
+    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
+    if new_order != tuple(range(x.type.ndim)):
+        x = x.dimshuffle(new_order)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
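The rewrite above can be exercised directly. A minimal sketch (not part of the commit; it only relies on the `rewrite_graph` and `equal_computations` helpers already used in the tests below):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

vec = pt.vector("vec")
# expand_dims only adds a length-1 axis, which the reshape destroys anyway
out = pt.expand_dims(vec, axis=0).reshape((2, -1))
new_out = rewrite_graph(out, include=("canonicalize",))
# The useless DimShuffle is gone; the reshape acts on `vec` directly
assert equal_computations([new_out], [vec.reshape((2, -1))])
```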
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
...@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node): ...@@ -920,10 +917,10 @@ def local_useless_reshape(fgraph, node):
shape_feature = getattr(fgraph, "shape_feature", None) shape_feature = getattr(fgraph, "shape_feature", None)
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1 # Match case where at least (n-1) entries correspond to the original shape:
# or cases where all but one dimension are provably preserved # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]])
# Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no reshape.
output_shape_is = _unpack_shape_vector(output_shape) output_shape_is = _unpack_shape_vector(output_shape)
nb_m1 = 0 nb_m1 = 0
shape_match = [False] * inp.type.ndim shape_match = [False] * inp.type.ndim
for dim in range(inp.type.ndim): for dim in range(inp.type.ndim):
...@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node): ...@@ -935,48 +932,136 @@ def local_useless_reshape(fgraph, node):
nb_m1 += 1 nb_m1 += 1
if nb_m1 <= 1 and all(shape_match): if nb_m1 <= 1 and all(shape_match):
return [inp] return [inp] # This is provably correct
# There is one missing match, but all other dimensions match # There is one missing match, but all other dimensions match
# Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y)
if (nb_m1 == 0) and (shape_match.count(False) == 1): if (nb_m1 == 0) and (shape_match.count(False) == 1):
return [inp] return [inp] # This could mask a shape error
return False return False
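A sketch of the provably-correct branch (again not from the commit, and assuming `ShapeOpt` is included so the `shape_feature` the rewrite consults is installed): when every target entry is traceable to the input's own shape, the whole Reshape is removed.

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

x = pt.tensor3("x", shape=(3, 5, None))
# Every target entry matches the corresponding input dimension
out = x.reshape((3, 5, x.shape[2]))
new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
assert equal_computations([new_out], [x])
```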
 
 
-@register_canonicalize
+@register_canonicalize("shape_unsafe")
 @node_rewriter([Reshape])
 def local_reshape_to_dimshuffle(fgraph, node):
-    r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s.
+    r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.
 
-    The goal is to avoid using `Reshape` to add or remove broadcastable
-    dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can
-    cancel out and/or be removed later on.
+    It's always valid to squeeze an input before doing the same reshape operation.
+    Equivalently, it's always valid to remove `1` entries from the reshape shape
+    and replace them by an expand_dims after the rewritten reshape operation.
+
+    We chose to canonicalize the graph in this way as it allows isolating
+    operations that are unique to the reshaping operation (mixing dimensions)
+    from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims).
+
+    This can allow further simplifications by other rewrites that target
+    DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations.
 
     For example:
-      - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)))
-      - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+      - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
+      - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
+      - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))
 
     """
     inp, output_shape = node.inputs
     [output] = node.outputs
 
-    unpacked_shape = _unpack_shape_vector(output_shape)
-    expand_axes = []
-    new_output_shape = []
-    for i, dim in enumerate(unpacked_shape):
-        if isinstance(dim, Constant) and dim.data == 1:
-            expand_axes.append(i)
-        else:
-            new_output_shape.append(dim)
+    # Remove any broadcastable dimensions from the input
+    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
+
+    # Trivial case, all dimensions of input/output are known to be broadcastable:
+    # there's nothing to reshape
+    if all(inp.type.broadcastable) or all(output.type.broadcastable):
+        new_output_shape = []
+        expand_axes = tuple(range(output.type.ndim))
+    else:
+        unpacked_shape = _unpack_shape_vector(output_shape)
+        new_output_shape = []
+        expand_axes = []
+        for i, dim_length in enumerate(unpacked_shape):
+            if isinstance(dim_length, Constant) and (
+                dim_length.data == 1
+                # -1 can be an implicit expand_dims, but it's tricky to prove
+                # as we would need to check whether all other dimensions
+                # already explain the full size of the array.
+                # Example: np.zeros((2, 2, 2)).reshape((8, -1))
+                # We rely on the output static shape which will already have figured
+                # it out for some (but not all) cases
+                or (dim_length.data == -1 and output.type.shape[i] == 1)
+            ):
+                expand_axes.append(i)
+            else:
+                new_output_shape.append(dim_length)
 
-    if len(new_output_shape) != output.type.ndim:
-        inner = inp.reshape(new_output_shape)
-        copy_stack_trace(output, inner)
-        new_out = expand_dims(inner, expand_axes)
+    if squeeze_axes or expand_axes:
+        new_out = inp.squeeze(squeeze_axes)
+
+        if new_output_shape:
+            new_out = new_out.reshape(new_output_shape)
+            copy_stack_trace(output, new_out)
+
+        new_out = expand_dims(new_out, expand_axes)
+
+        if not new_output_shape:
+            # Eagerly merge consecutive squeeze and expand_dims
+            new_out = apply_local_dimshuffle_lift(fgraph, new_out)
+
         copy_stack_trace(output, new_out)
         return [new_out]
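The docstring's last example can be checked end to end; this sketch is illustrative rather than part of the commit:

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

x = pt.matrix("x")
# All length-1 dimensions are peeled off the reshape...
out = x.reshape((1, 2, 1, 3, 1, 1))
new_out = rewrite_graph(out, include=("canonicalize",))
# ...and reintroduced as a single expand_dims around a 2D reshape
assert equal_computations(
    [new_out], [pt.expand_dims(x.reshape((2, 3)), (0, 2, 4, 5))]
)
```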
+
+
+@register_specialize
+@node_rewriter([Reshape])
+def local_fuse_squeeze_reshape(fgraph, node):
+    r"""If there is a squeeze right before a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    x, new_shape = node.inputs
+
+    if (
+        x.owner is not None
+        and isinstance(x.owner.op, DimShuffle)
+        and x.owner.op.is_squeeze
+    ):
+        # A reshape can always subsume a squeeze.
+        x = x.owner.inputs[0]
+        return [x.reshape(new_shape)]
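A sketch of this specialization undoing the canonical squeeze (not from the commit; it mirrors `test_expand_dims_squeeze_reshape_fusion` below):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

x = pt.tensor("x", shape=(1, 9))
out = x.squeeze(0).reshape((3, 3))
new_out = rewrite_graph(out, include=("specialize",))
# The squeeze is subsumed by the reshape
assert equal_computations([new_out], [x.reshape((3, 3))])
```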
+
+
+@register_specialize
+@node_rewriter([DimShuffle])
+def local_fuse_expand_dims_reshape(fgraph, node):
+    r"""If there is an expand_dims right after a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    if not node.op.is_expand_dims:
+        return None
+
+    reshaped_x = node.inputs[0]
+
+    if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)):
+        return None
+
+    if len(fgraph.clients[reshaped_x]) > 1:
+        # The reshape is used elsewhere, don't fuse as it can sometimes require a copy.
+        # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[:, None] * y[None, :]`
+        return None
+
+    x, new_shape = reshaped_x.owner.inputs
+
+    # Insert the expanded dims into the reshape's target shape
+    new_shape = list(_unpack_shape_vector(new_shape))
+    for i in node.op.augment:
+        new_shape.insert(i, 1)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
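And the mirror-image case, where the new axes of an expand_dims are folded into the reshape's target shape (again an illustrative sketch, not commit code):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations

x = pt.matrix("x")
out = pt.expand_dims(x.reshape((3, 3)), axis=0)
new_out = rewrite_graph(out, include=("specialize",))
# The expanded axis becomes a literal 1 in the reshape's target shape
assert equal_computations([new_out], [x.reshape((1, 3, 3))])
```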
 
 
 @register_canonicalize
 @register_specialize
 @node_rewriter([Reshape])
...
@@ -332,7 +332,6 @@ class TestLocalCanonicalizeAlloc:
         mode = rewrite_mode.including(
             "local_dimshuffle_lift",
-            "local_useless_dimshuffle_in_reshape",
             "local_alloc_sink_dimshuffle",
         )
         f = function([x], [y], mode=mode)
...
@@ -56,7 +56,10 @@ from pytensor.tensor.math import pow as pt_pow
 from pytensor.tensor.math import round as pt_round
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
-from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
+from pytensor.tensor.rewriting.shape import (
+    local_fuse_squeeze_reshape,
+    local_useless_expand_dims_in_reshape,
+)
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import (
     TensorType,
@@ -182,7 +185,7 @@ class TestDimshuffleLift:
     assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
 
 
-def test_local_useless_dimshuffle_in_reshape():
+def test_local_useless_expand_dims_in_reshape():
     vec = TensorType(dtype="float64", shape=(None,))("vector")
     mat = TensorType(dtype="float64", shape=(None, None))("mat")
     row = TensorType(dtype="float64", shape=(1, None))("row")
@@ -204,7 +207,11 @@ def test_local_useless_expand_dims_in_reshape():
         clone=False,
     )
     assert len(g.apply_nodes) == 4 * 3
-    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
+    useless_dimshuffle_in_reshape = out2in(
+        local_useless_expand_dims_in_reshape,
+        # A useless squeeze in a reshape is no longer a canonicalization
+        local_fuse_squeeze_reshape,
+    )
     useless_dimshuffle_in_reshape.rewrite(g)
     assert equal_computations(
         g.outputs,
@@ -218,15 +225,12 @@ def test_local_useless_expand_dims_in_reshape():
     # Check stacktrace was copied over correctly after rewrite was applied
     assert check_stack_trace(g, ops_to_check="all")
 
-    # Check that the rewrite does not get applied when the order
-    # of dimensions has changed.
+    # Check that the rewrite does not mess up meaningful transpositions before the reshape
     reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
     h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
     assert len(h.apply_nodes) == 3
     useless_dimshuffle_in_reshape.rewrite(h)
-    assert equal_computations(
-        h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
-    )
+    assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)])
 
 
 class TestFusion:
...
@@ -6,7 +6,7 @@ import pytest
 import pytensor.tensor as pt
 from pytensor import shared
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_default_mode, get_mode
+from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
@@ -426,6 +426,60 @@ class TestLocalReshapeToDimshuffle:
         assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
 
+    def test_expand_dims(self):
+        x = pt.scalar()
+        # This reshape does an implicit expand_dims
+        out = x.reshape((1, -1))
+        assert isinstance(out.owner.op, Reshape)
+        new_out = rewrite_graph(out, include=("canonicalize",))
+        assert equal_computations([new_out], [pt.expand_dims(x, (0, 1))])
+
+    def test_squeeze_of_alloc(self):
+        # This shows up in the graph of repeat
+        x = pt.vector("x", shape=(9,))
+        bcast_x = pt.alloc(x, 1, 12, x.shape[0])
+
+        # This reshape does an implicit squeeze
+        out = bcast_x.reshape((12, x.shape[0]))
+        new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
+        assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
+
+
+def test_expand_dims_squeeze_reshape_fusion():
+    x = pt.tensor("x", shape=(1, 9))
+    reshape_x = x.squeeze(0).reshape((3, 3))[..., None]
+
+    assert isinstance(reshape_x.owner.op, DimShuffle)
+    assert isinstance(reshape_x.owner.inputs[0].owner.op, Reshape)
+    assert isinstance(reshape_x.owner.inputs[0].owner.inputs[0].owner.op, DimShuffle)
+
+    out = rewrite_graph(reshape_x, include=("specialize",))
+
+    # In this case we cannot get rid of the reshape, squeeze or expand_dims,
+    # so we fuse them all in one reshape
+    assert equal_computations([out], [x.reshape((3, 3, 1))])
+
+
+def test_implicit_broadcasting_via_repeat():
+    x = pt.vector("x", shape=(3,), dtype=int)
+    y = pt.vector("y", shape=(9,), dtype=int)
+    out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1)
+
+    # There are two Reshapes in the graph
+    assert isinstance(out.owner.inputs[0].owner.op, Reshape)
+    assert isinstance(out.owner.inputs[1].owner.op, Reshape)
+
+    new_out = rewrite_graph(out, include=("canonicalize", "specialize"))
+    assert equal_computations([new_out], [x[None] <= y[:, None]])
+
+    no_rewrite_mode = Mode(linker="py", optimizer=None)
+    x_test = np.arange(3) + 1
+    y_test = np.arange(9)
+    np.testing.assert_allclose(
+        new_out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode),
+        out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode),
+    )
+
+
 def test_local_reshape_lift():
     x = tensor4()
...