提交 2e2c871e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add rewrite for Blockwise with Alloc inputs

Also prevent Alloc from constant_folding when it's used by Elemwise and Blockwise to avoid creating useless large arrays
上级 fe06ee32
......@@ -1777,6 +1777,7 @@ def equal_computations(
ys: list[Union[np.ndarray, Variable]],
in_xs: Optional[list[Variable]] = None,
in_ys: Optional[list[Variable]] = None,
strict_dtype=True,
) -> bool:
"""Checks if PyTensor graphs represent the same computations.
......@@ -1908,6 +1909,9 @@ def equal_computations(
if dx != dy:
if isinstance(dx, Constant) and isinstance(dy, Constant):
if not dx.equals(dy):
if strict_dtype:
return False
elif not np.array_equal(dx.data, dy.data):
return False
else:
return False
......
......@@ -42,6 +42,7 @@ from pytensor.tensor import (
as_tensor_variable,
get_vector_length,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import (
......@@ -1658,16 +1659,22 @@ class Alloc(COp):
if not clients:
return False
for client in clients:
if client[0] == "output":
for client, idx in clients:
if client == "output":
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
return False
# Allow alloc to be lifted out of Elemwise before constant folding it
elif isinstance(client.op, Elemwise):
return None
# Same for Blockwise, unless it has no batch_dims
elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
return None
elif (
# The following ops work inplace of their input id 0.
client[1] == 0
idx == 0
and isinstance(
client[0].op,
client.op,
(
# Ops that will work inplace on the Alloc. So if they
# get constant_folded, they would copy the
......
from typing import Optional
from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, shape_padleft
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import (
......@@ -80,3 +82,120 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
),
):
return local_useless_unbatched_blockwise.fn(fgraph, node)
def _squeeze_left(x, stop_at_dim: Optional[int] = None):
    """Drop leading broadcastable (dummy) dims of `x`.

    Leading dims are squeezed until the first non-broadcastable dim is
    reached, or until `stop_at_dim` dims have been removed (when it is
    not None). `x` is returned unchanged when there is nothing to squeeze.
    """
    bcast = x.type.broadcastable
    # Number of leading broadcastable dims: position of the first real dim,
    # or the full ndim when every dim is broadcastable.
    try:
        n_leading = bcast.index(False)
    except ValueError:
        n_leading = len(bcast)
    if stop_at_dim is not None and stop_at_dim < n_leading:
        n_leading = stop_at_dim
    if not n_leading:
        return x
    return x.squeeze(axis=tuple(range(n_leading)))
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
    """Push Allocs from the inputs to the output of Blockwise Ops.

    Examples, with BOp = Blockwise(Op, signature="(x),(x)->(x)"):

        BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)
        BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp(vector, alloc(scalar, 5)), 10, 5)
        BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)

    Returns the rewritten outputs, or None when no Alloc input can be pushed.
    """
    # Bail out early when no input comes from an Alloc
    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
        return None

    op: Blockwise = node.op  # type: ignore

    batch_ndim = op.batch_ndim(node)
    if not batch_ndim:
        # Without batch dims there is nothing to push out of the Blockwise
        return None

    new_inputs = []
    batch_shapes = []
    can_push_any_alloc = False
    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
        if inp.owner and isinstance(inp.owner.op, Alloc):
            # Push batch dims from Alloc
            value, *shape = inp.owner.inputs

            # Check what to do with the value of the Alloc:
            # dummy leading batch dims of the value can simply be dropped
            squeezed_value = _squeeze_left(value, batch_ndim)
            missing_ndim = len(shape) - value.type.ndim
            if (
                ((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]
            ) != inp.type.broadcastable[batch_ndim:]:
                # We still need an Alloc for the core dims
                core_shape = shape[batch_ndim:]
                # And the batch dims of the squeezed value
                squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
                batch_shape = [
                    1 if broadcastable else dim
                    for broadcastable, dim in zip(
                        squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
                        tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
                    )
                ]
                squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
                if squeezed_value.type.broadcastable == inp.type.broadcastable:
                    # We can't change anything about this Alloc input
                    new_inputs.append(inp)
                    continue

            # We can push batch dims of this Alloc input.
            # Record its batch shape (1 for broadcastable dims) so the output
            # Alloc below can pick the most informative dim for each axis.
            batch_shapes.append(
                tuple(
                    1 if broadcastable else dim
                    for broadcastable, dim in zip(
                        inp.type.broadcastable, shape[:batch_ndim]
                    )
                )
            )
            new_inputs.append(squeezed_value)
            can_push_any_alloc = True
        else:
            # Nothing to do with this input other than removing dummy batch dims
            new_inputs.append(_squeeze_left(inp, batch_ndim))

    if not can_push_any_alloc:
        return None

    # Rebuild the Blockwise on the squeezed/unwrapped inputs
    new_outs = node.op.make_node(*new_inputs).outputs

    new_out_type = new_outs[0].type
    old_out_type = node.outputs[0].type
    if new_out_type.broadcastable != old_out_type.broadcastable:
        # An Alloc is still needed to broadcast the new output to the original shape
        # We pick the most parsimonious batch dim from the pushed Alloc
        missing_ndim = old_out_type.ndim - new_out_type.ndim
        batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
        for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
            for batch_dim in batch_dims:
                if batch_dim == 1:
                    # A broadcastable dim carries no shape information
                    continue
                if isinstance(batch_dim, Constant):
                    # Give preference to Constants
                    batch_shape[i] = batch_dim
                    break
                elif old_out_type.broadcastable[i]:
                    # Only use non Constant shapes if absolutely necessary
                    # Otherwise, we use the shape of the non-alloc output
                    batch_shape[i] = batch_dim

        copy_stack_trace(node.outputs, new_outs)
        new_outs = [
            alloc(
                new_out,
                *batch_shape,
                *tuple(new_out.shape)[batch_ndim - missing_ndim :],
            )
            for new_out in new_outs
        ]
    # The rewritten graph must have exactly the original broadcast pattern
    assert new_outs[0].type.broadcastable == old_out_type.broadcastable
    copy_stack_trace(node.outputs, new_outs)
    return new_outs
from functools import partial
from pytensor import function
from pytensor.graph import FunctionGraph
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.graph.basic import equal_computations
from pytensor.scalar import log as scalar_log
from pytensor.tensor import matrix, tensor3
from pytensor.tensor import add, alloc, matrix, tensor, tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv
......@@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise():
fn = function([x], out, mode="FAST_COMPILE")
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
def test_blockwise_alloc():
    """Check that `local_blockwise_alloc` pushes Alloc inputs out of Blockwise."""
    rewrite = partial(
        rewrite_graph,
        include=("ShapeOpt", "specialize"),
        exclude=("local_useless_unbatched_blockwise",),
    )

    vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")

    # Depending on the rewrites the Alloc shape may be upcast to int64 or not.
    # We do not care about that for the purposes of this test.
    equal = partial(equal_computations, strict_dtype=False)

    # Case where Alloc is not necessary
    x = tensor("x", shape=(7, 5))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = vector_add(x, y)
    assert equal([rewrite(out)], [expected_out])

    # Cases where Alloc can be fully pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, y), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(1, 5))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x.squeeze(0), y), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(7, 5))
    y = tensor("y", shape=(7, 5))
    out = vector_add(x, alloc(y, 3, 7, 5))
    expected_out = alloc(vector_add(x, y), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(7, 1, 5))
    out = vector_add(x, alloc(y, 7, 2, 5))
    expected_out = alloc(vector_add(x, y), 7, 2, 5)
    assert equal([rewrite(out)], [expected_out])

    # Cases where Alloc can be partially pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, alloc(y, 5)), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(7, 1, 1))
    out = vector_add(x, alloc(y, 7, 2, 5))
    expected_out = alloc(vector_add(x, alloc(y, 7, 1, 5)), 7, 2, 5)
    # `equal` already binds strict_dtype=False; no need to repeat it here
    assert equal([rewrite(out)], [expected_out])

    # Cases involving multiple Allocs being pushed
    x = tensor("x", shape=())
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, alloc(y, 5)), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    # Case where Alloc cannot be pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(1,))
    out = vector_add(x, alloc(y, 5))
    expected_out = out
    assert equal([rewrite(out)], [expected_out])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论