Commit dbf5f38e authored by Ricardo Vieira, committed by Ricardo Vieira

Refactor reshape + dimshuffle rewrites

Parent 02545ed5
......@@ -2800,16 +2800,6 @@ def _check_chain(r, chain):
     return r is not None
 
 
-def check_chain(r, *chain):
-    """
-    WRITEME
-    """
-    if isinstance(r, Apply):
-        r = r.outputs[0]
-    return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
-
-
 def pre_greedy_node_rewriter(
     fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable
 ) -> Variable:
......
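The deleted `check_chain` helper tested whether a variable was produced by a given chain of op types; the call sites below now spell this out with explicit `owner`/`op` checks. A minimal sketch of the replacement idiom (the function name is hypothetical, not part of the commit):

from pytensor.tensor.shape import Reshape

def is_reshape_of_reshape(node):
    # Equivalent of the old check_chain(node, Reshape, Reshape) at a call
    # site where `node` is already known to be a Reshape node: its first
    # input must itself be the output of another Reshape.
    inner, _ = node.inputs
    return inner.owner is not None and isinstance(inner.owner.op, Reshape)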
......@@ -12,16 +12,17 @@ from pytensor.graph.features import AlreadyThere, Feature
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     GraphRewriter,
-    check_chain,
     copy_stack_trace,
     node_rewriter,
 )
 from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
 from pytensor.scalar import ScalarType
 from pytensor.tensor.basic import (
     MakeVector,
     as_tensor_variable,
     cast,
     constant,
+    expand_dims,
     get_scalar_constant_value,
     register_infer_shape,
     stack,
......@@ -47,6 +48,7 @@ from pytensor.tensor.shape import (
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
 from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.variable import TensorVariable
 
 
 class ShapeFeature(Feature):
......@@ -755,6 +757,42 @@ pytensor.compile.mode.optdb.register(
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
 
 
+@register_canonicalize
+@node_rewriter([Reshape])
+def local_useless_dimshuffle_in_reshape(fgraph, node):
+    """
+    Removes useless DimShuffle operation inside Reshape:
+
+      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
+      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
+      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
+      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+    """
+    dimshuffled_x, new_shape = node.inputs
+
+    if not (
+        dimshuffled_x.owner is not None
+        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+    ):
+        return False
+
+    [inp] = dimshuffled_x.owner.inputs
+    new_order = dimshuffled_x.owner.op.new_order
+
+    new_order_of_nonbroadcast = []
+    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
+        if s != 1:
+            new_order_of_nonbroadcast.append(i)
+    no_change_in_order = all(
+        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
+        for i in range(len(new_order_of_nonbroadcast) - 1)
+    )
+    if no_change_in_order:
+        ret = inp.reshape(new_shape)
+        copy_stack_trace(node.outputs[0], ret)
+        return [ret]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Reshape])
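A quick illustration of the relocated rewrite (a sketch assuming the default canonicalization pass; the shape (2, 3) is arbitrary): the DimShuffle only inserts a broadcastable axis, so reshaping the underlying vector directly is equivalent.

import pytensor
import pytensor.tensor as pt

v = pt.vector("v")
out = v.dimshuffle("x", 0).reshape((2, 3))  # reshape(dimshuffle('x', 0)(v), shp)
fn = pytensor.function([v], out)            # expected to canonicalize to reshape(v, shp)
pytensor.dprint(fn)                         # the DimShuffle should no longer appear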
......@@ -763,30 +801,89 @@ def local_reshape_chain(fgraph, node):
     Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
 
     """
-    if not check_chain(node, Reshape, Reshape):
-        return False
-
-    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-    # Copy over stacktrace from previous output node, as any error
-    # in new computational graph would have been caused by last op
-    # in the old computational graph.
-    copy_stack_trace(node.outputs, rval)
-
-    # It might happen that the desired output of this node has a
-    # broadcastable pattern that does not match that of 'rval'. This is
-    # when originally, we were able to figure out that one of the
-    # dimensions of the reshape is one, but some other transformation
-    # replaced the shape by one for which this cannot be guessed.
-    # We should try to figure out why we lost the information about this
-    # constant value... but in the meantime, better not apply this
-    # rewrite.
-    if rval.type.ndim == node.outputs[0].type.ndim and all(
-        s1 == s2
-        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True)
-        if s1 == 1 or s2 == 1
-    ):
-        return [rval]
+    inner_reshape, final_shape = node.inputs
+    if not (inner_reshape.owner and isinstance(inner_reshape.owner.op, Reshape)):
+        return None
+
+    x, _ = inner_reshape.owner.inputs
+    new_reshape = node.op(x, final_shape)
+
+    copy_stack_trace(node.outputs, new_reshape)
+    return [new_reshape]
+
+
+def _is_shape_i_of_x(
+    var: TensorVariable,
+    x: TensorVariable,
+    i: int,
+    shape_feature: ShapeFeature | None = None,
+) -> bool:
+    if var.type.ndim != 0:
+        return False
+
+    constant_var = get_scalar_constant_value(
+        var,
+        only_process_constants=False,
+        # Don't go through Elemwise to keep things fast
+        elemwise=False,
+        raise_not_constant=False,
+    )
+
+    # Check var is a constant expression with the same value as x.type.shape[i]
+    if constant_var == x.type.shape[i]:
+        return True
+
+    # Match shape_of[x][i] or its constant equivalent
+    if shape_feature is not None:
+        i_shape_of_x = shape_feature.get_shape(x, i)
+        if i_shape_of_x == var or (
+            isinstance(i_shape_of_x, Constant) and (i_shape_of_x.data == constant_var)
+        ):
+            return True
+
+    if var.owner is None:
+        # No more constant possibilities
+        return False
+
+    # Match Shape_i{i}(x)
+    if isinstance(var.owner.op, Shape_i):
+        return (var.owner.op.i == i) and (var.owner.inputs[0] == x)  # type: ignore
+
+    # Match Subtensor((ScalarType,))(Shape(input), i)
+    if isinstance(var.owner.op, Subtensor):
+        return (
+            # Check we have integer indexing operation
+            # (and not slice or multiple indexing)
+            len(var.owner.op.idx_list) == 1
+            and isinstance(var.owner.op.idx_list[0], ScalarType)
+            # Check we are indexing on the shape of x
+            and var.owner.inputs[0].owner is not None
+            and isinstance(var.owner.inputs[0].owner.op, Shape)
+            and var.owner.inputs[0].owner.inputs[0] == x
+            # Check that index == i
+            and (
+                get_scalar_constant_value(var.owner.inputs[1], raise_not_constant=False)
+                == i
+            )
+        )
+
+    return False
+
+
+def _unpack_shape_vector(shape: TensorVariable) -> tuple[TensorVariable, ...]:
+    """Return the elements of a symbolic vector representing a shape.
+
+    Handles the most common constant vector or make_vector cases.
+
+    Returns tuple(shape) as fallback.
+    """
+    if isinstance(shape, Constant):
+        return tuple(as_tensor_variable(dim, ndim=0) for dim in shape.data)
+    elif shape.owner and isinstance(shape.owner.op, MakeVector):
+        return tuple(shape.owner.inputs)
+    else:
+        return tuple(shape)
 
 
 @register_useless("shape_unsafe")
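The two new helpers centralize pattern-matching that used to be inlined in local_useless_reshape. A hypothetical illustration (not part of the commit) of the expressions `_is_shape_i_of_x` is designed to recognize as dimension i of x, and of the shape vectors `_unpack_shape_vector` splits apart:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 3))

# Each of these denotes x's shape at some index i:
shape_i_exprs = [
    x.shape[0],      # Subtensor((ScalarType,))(Shape(x), 0) -> matches i == 0
    pt.constant(3),  # constant equal to the static x.type.shape[1] -> matches i == 1
]

# _unpack_shape_vector returns per-dimension entries: a MakeVector is
# unpacked into its inputs, a Constant into scalar constants.
shape_vec = pt.stack([x.shape[0], pt.constant(3)])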
......@@ -821,87 +918,30 @@ def local_useless_reshape(fgraph, node):
     if shape_input == inp:
         return [inp]
 
-    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
-    # broadcastable and constant dimensions
-    if isinstance(output_shape, Constant) or (
-        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
-    ):
-        if isinstance(output_shape, Constant):
-            output_shape_is = [
-                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
-            ]
-        else:
-            output_shape_is = output_shape.owner.inputs
-
-        shape_feature = getattr(fgraph, "shape_feature", None)
-
-        nb_m1 = 0
-        shape_match = [False] * inp.type.ndim
-        for dim in range(inp.type.ndim):
-            outshp_i = output_shape_is[dim]
-            # Match Shape_i{dim}(input)
-            if (
-                outshp_i.owner
-                and isinstance(outshp_i.owner.op, Shape_i)
-                and outshp_i.owner.op.i == dim
-                and outshp_i.owner.inputs[0] == inp
-            ):
-                shape_match[dim] = True
-                continue
-
-            # Match Shape(input)[dim]
-            if (
-                outshp_i.owner
-                and isinstance(outshp_i.owner.op, Subtensor)
-                and len(outshp_i.owner.inputs) == 2
-                and get_scalar_constant_value(
-                    outshp_i.owner.inputs[1], raise_not_constant=False
-                )
-                == dim
-            ):
-                subtensor_inp = outshp_i.owner.inputs[0]
-                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
-                    shape_input_i = subtensor_inp.owner.inputs[0]
-                    if shape_input_i == inp:
-                        shape_match[dim] = True
-                        continue
-
-            # Match constant if input.type.shape[dim] == constant
-            cst_outshp_i = get_scalar_constant_value(
-                outshp_i, only_process_constants=True, raise_not_constant=False
-            )
-            if inp.type.shape[dim] == cst_outshp_i:
-                shape_match[dim] = True
-                continue
-
-            # Match -1
-            if cst_outshp_i == -1:
-                shape_match[dim] = True
-                nb_m1 += 1
-                continue
-
-            # Match shape_of[input][dim] or its constant equivalent
-            if shape_feature:
-                inpshp_i = shape_feature.get_shape(inp, dim)
-                if inpshp_i == outshp_i or (
-                    get_scalar_constant_value(
-                        inpshp_i, only_process_constants=True, raise_not_constant=False
-                    )
-                    == get_scalar_constant_value(
-                        outshp_i, only_process_constants=True, raise_not_constant=False
-                    )
-                ):
-                    shape_match[dim] = True
-                    continue
-
-        if nb_m1 <= 1 and all(shape_match):
-            return [inp]
-
-        # There is one missing match, but all other dimensions match
-        if (nb_m1 == 0) and (shape_match.count(False) == 1):
-            return [inp]
-
-        return False
+    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
+    # or cases where all but one dimension are provably preserved
+    output_shape_is = _unpack_shape_vector(output_shape)
+
+    shape_feature = getattr(fgraph, "shape_feature", None)
+
+    nb_m1 = 0
+    shape_match = [False] * inp.type.ndim
+    for dim in range(inp.type.ndim):
+        outshp_i = output_shape_is[dim]
+        if _is_shape_i_of_x(outshp_i, inp, dim, shape_feature=shape_feature):
+            shape_match[dim] = True
+        elif isinstance(outshp_i, Constant) and outshp_i.data == -1:
+            shape_match[dim] = True
+            nb_m1 += 1
+
+    if nb_m1 <= 1 and all(shape_match):
+        return [inp]
+
+    # There is one missing match, but all other dimensions match
+    if (nb_m1 == 0) and (shape_match.count(False) == 1):
+        return [inp]
+
+    return False
 
 
 @register_canonicalize
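With the helpers, the rewrite can prove cases like the following, where every output dimension is dimension i of the input or a single -1 (a sketch assuming the default rewrite pipeline):

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = x.reshape((x.shape[0], -1))  # dim 0 matches Shape(x)[0], dim 1 is the lone -1
fn = pytensor.function([x], out)   # the Reshape is expected to be removed entirely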
......@@ -915,39 +955,26 @@ def local_reshape_to_dimshuffle(fgraph, node):
     For example:
       - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)))
-      - reshape(x, (1, m, 1, n, 1, 1))
-        -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+      - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
     """
-    op = node.op
     inp, output_shape = node.inputs
     [output] = node.outputs
 
-    dimshuffle_new_order = []
-    new_output_shape = []
-    index = 0  # index over the output of the new reshape
-    for i in range(output.ndim):
-        # Since output_shape is a symbolic vector, we trust get_scalar_constant_value
-        # to go through however it is formed to see if its i-th element is 1.
-        # We need only_process_constants=False for that.
-        dim = get_scalar_constant_value(
-            output_shape[i],
-            only_process_constants=False,
-            elemwise=False,
-            raise_not_constant=False,
-        )
-        if dim == 1:
-            dimshuffle_new_order.append("x")
-        else:
-            dimshuffle_new_order.append(index)
-            new_output_shape.append(dim)
-            index = index + 1
-
-    if index != output.type.ndim:
-        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
-        copy_stack_trace(output, inner)
-        new_node = [inner.dimshuffle(dimshuffle_new_order)]
-        copy_stack_trace(output, new_node)
-        return new_node
+    unpacked_shape = _unpack_shape_vector(output_shape)
+    expand_axes = []
+    new_output_shape = []
+    for i, dim in enumerate(unpacked_shape):
+        if isinstance(dim, Constant) and dim.data == 1:
+            expand_axes.append(i)
+        else:
+            new_output_shape.append(dim)
+
+    if len(new_output_shape) != output.type.ndim:
+        inner = inp.reshape(new_output_shape)
+        copy_stack_trace(output, inner)
+        new_out = expand_dims(inner, expand_axes)
+        copy_stack_trace(output, new_out)
+        return [new_out]
 
 
 @register_canonicalize
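Sketch of the refactored behavior: constant 1s in the target shape are stripped from the inner Reshape and reinstated with expand_dims. Illustrative only, and assuming `rewrite_graph` keeps its current signature:

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.tensor3("x")
out = x.reshape((1, 6, 1))
# Expected canonical form: expand_dims(x.reshape((6,)), (0, 2))
canonical = rewrite_graph(out, include=("canonicalize",))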
......@@ -1187,44 +1214,6 @@ def local_track_shape_i(fgraph, node):
     return [shape_feature.shape_of[replacement][node.op.i]]
 
 
-@register_canonicalize
-@node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
-    """
-    Removes useless DimShuffle operation inside Reshape:
-
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
-    """
-    op = node.op
-    if not isinstance(op, Reshape):
-        return False
-    if not (
-        node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, DimShuffle)
-    ):
-        return False
-
-    new_order = node.inputs[0].owner.op.new_order
-    inp = node.inputs[0].owner.inputs[0]
-
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        shape = node.inputs[1]
-        ret = op.__class__(node.outputs[0].ndim)(inp, shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
-
-
 @register_useless
 @register_canonicalize
 @register_specialize
......