提交 2e2c871e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add rewrite for Blockwise with Alloc inputs

Also prevent Alloc from constant_folding when it's used by Elemwise and Blockwise to avoid creating useless large arrays
上级 fe06ee32
...@@ -1777,6 +1777,7 @@ def equal_computations( ...@@ -1777,6 +1777,7 @@ def equal_computations(
ys: list[Union[np.ndarray, Variable]], ys: list[Union[np.ndarray, Variable]],
in_xs: Optional[list[Variable]] = None, in_xs: Optional[list[Variable]] = None,
in_ys: Optional[list[Variable]] = None, in_ys: Optional[list[Variable]] = None,
strict_dtype=True,
) -> bool: ) -> bool:
"""Checks if PyTensor graphs represent the same computations. """Checks if PyTensor graphs represent the same computations.
...@@ -1908,7 +1909,10 @@ def equal_computations( ...@@ -1908,7 +1909,10 @@ def equal_computations(
if dx != dy: if dx != dy:
if isinstance(dx, Constant) and isinstance(dy, Constant): if isinstance(dx, Constant) and isinstance(dy, Constant):
if not dx.equals(dy): if not dx.equals(dy):
return False if strict_dtype:
return False
elif not np.array_equal(dx.data, dy.data):
return False
else: else:
return False return False
......
...@@ -42,6 +42,7 @@ from pytensor.tensor import ( ...@@ -42,6 +42,7 @@ from pytensor.tensor import (
as_tensor_variable, as_tensor_variable,
get_vector_length, get_vector_length,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
...@@ -1658,16 +1659,22 @@ class Alloc(COp): ...@@ -1658,16 +1659,22 @@ class Alloc(COp):
if not clients: if not clients:
return False return False
for client in clients: for client, idx in clients:
if client[0] == "output": if client == "output":
# If the output is a constant, it will have to be deepcopied # If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold. # each time the function is called. So we do not fold.
return False return False
# Allow alloc to be lifted out of Elemwise before constant folding it
elif isinstance(client.op, Elemwise):
return None
# Same for Blockwise, unless it has no batch_dims
elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
return None
elif ( elif (
# The following ops work inplace of their input id 0. # The following ops work inplace of their input id 0.
client[1] == 0 idx == 0
and isinstance( and isinstance(
client[0].op, client.op,
( (
# Ops that will work inplace on the Alloc. So if they # Ops that will work inplace on the Alloc. So if they
# get constant_folded, they would copy the # get constant_folded, they would copy the
......
from typing import Optional
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
...@@ -80,3 +82,120 @@ def local_eager_useless_unbatched_blockwise(fgraph, node): ...@@ -80,3 +82,120 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
), ),
): ):
return local_useless_unbatched_blockwise.fn(fgraph, node) return local_useless_unbatched_blockwise.fn(fgraph, node)
def _squeeze_left(x, stop_at_dim: Optional[int] = None):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims = x.type.broadcastable
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
if stop_at_dim is not None:
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
if squeeze_ndim == 0:
return x
return x.squeeze(axis=tuple(range(squeeze_ndim)))
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
    """Push Allocs from the inputs to the output of Blockwise Ops.

    BOp = Blockwise(Op, signature="(x),(x)->(x)")
    BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp)(vector, vector), 10, 5)
    BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp)(vector, alloc(scalar, 5), 10, 5)
    BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)

    Returns the rewritten outputs (wrapped in a single Alloc when still
    needed to reach the original broadcast pattern), or None when no Alloc
    input can be pushed.
    """
    # Fast bail-out: nothing to do unless at least one input is an Alloc.
    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
        return None

    op: Blockwise = node.op  # type: ignore

    # Only batched Blockwise nodes are handled; without batch dims there is
    # nothing to push past the core computation.
    batch_ndim = op.batch_ndim(node)
    if not batch_ndim:
        return None

    new_inputs = []
    # One tuple of batch dims per pushed Alloc input; used later to rebuild
    # the output Alloc shape (1 marks a broadcastable dim).
    batch_shapes = []
    can_push_any_alloc = False
    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
        if inp.owner and isinstance(inp.owner.op, Alloc):
            # Push batch dims from Alloc
            value, *shape = inp.owner.inputs

            # Check what to do with the value of the Alloc
            squeezed_value = _squeeze_left(value, batch_ndim)
            # Dims the Alloc added on the left of `value`.
            missing_ndim = len(shape) - value.type.ndim
            if (
                ((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]
            ) != inp.type.broadcastable[batch_ndim:]:
                # We still need an Alloc for the core dims
                core_shape = shape[batch_ndim:]
                # And the batch dims of the squeezed value
                squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
                # Keep broadcastable batch dims as literal 1 so the type of the
                # re-created Alloc stays as broadcastable as the original value.
                batch_shape = [
                    1 if broadcastable else dim
                    for broadcastable, dim in zip(
                        squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
                        tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
                    )
                ]
                squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
                if squeezed_value.type.broadcastable == inp.type.broadcastable:
                    # We can't change anything about this Alloc input
                    new_inputs.append(inp)
                    continue

            # We can push batch dims of this Alloc input
            batch_shapes.append(
                tuple(
                    1 if broadcastable else dim
                    for broadcastable, dim in zip(
                        inp.type.broadcastable, shape[:batch_ndim]
                    )
                )
            )
            new_inputs.append(squeezed_value)
            can_push_any_alloc = True
        else:
            # Nothing to do with this input other than removing dummy batch dims
            new_inputs.append(_squeeze_left(inp, batch_ndim))

    if not can_push_any_alloc:
        return None

    # Re-apply the Blockwise on the squeezed inputs; broadcasting of the
    # surviving batch dims is handled by Blockwise itself.
    new_outs = node.op.make_node(*new_inputs).outputs

    new_out_type = new_outs[0].type
    old_out_type = node.outputs[0].type
    if new_out_type.broadcastable != old_out_type.broadcastable:
        # An Alloc is still needed to broadcast the new output to the original shape
        # We pick the most parsimonious batch dim from the pushed Alloc
        missing_ndim = old_out_type.ndim - new_out_type.ndim
        batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
        for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
            for batch_dim in batch_dims:
                if batch_dim == 1:
                    continue
                if isinstance(batch_dim, Constant):
                    # Give preference to Constants
                    batch_shape[i] = batch_dim
                    break
                elif old_out_type.broadcastable[i]:
                    # Only use non Constant shapes if absolutely necessary
                    # Otherwise, we use the shape of the non-alloc output
                    batch_shape[i] = batch_dim

        copy_stack_trace(node.outputs, new_outs)
        new_outs = [
            alloc(
                new_out,
                *batch_shape,
                # Core dims (plus any batch dims the new output already has)
                # come straight from the new output's own shape.
                *tuple(new_out.shape)[batch_ndim - missing_ndim :],
            )
            for new_out in new_outs
        ]
    # The rewrite must reproduce the original broadcast pattern exactly.
    assert new_outs[0].type.broadcastable == old_out_type.broadcastable
    copy_stack_trace(node.outputs, new_outs)
    return new_outs
from functools import partial
from pytensor import function from pytensor import function
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.graph.basic import equal_computations
from pytensor.scalar import log as scalar_log from pytensor.scalar import log as scalar_log
from pytensor.tensor import matrix, tensor3 from pytensor.tensor import add, alloc, matrix, tensor, tensor3
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv from pytensor.tensor.nlinalg import MatrixPinv
...@@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise(): ...@@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise():
fn = function([x], out, mode="FAST_COMPILE") fn = function([x], out, mode="FAST_COMPILE")
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv) assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
def test_blockwise_alloc():
    """Exercise `local_blockwise_alloc` over the representative Alloc layouts.

    Covers: a removable Alloc, Allocs that can be fully pushed past the
    Blockwise, partially pushable Allocs (core dims still need an Alloc),
    multiple pushed Allocs, and an Alloc that cannot be pushed at all.
    """
    rewrite = partial(
        rewrite_graph,
        include=("ShapeOpt", "specialize"),
        exclude=("local_useless_unbatched_blockwise",),
    )

    vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")

    # Depending on the rewrites the Alloc shape may be upcast to int64 or not
    # We do not care about that for the purposes of this test
    equal = partial(equal_computations, strict_dtype=False)

    # Case where Alloc is not necessary
    x = tensor("x", shape=(7, 5))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = vector_add(x, y)
    assert equal([rewrite(out)], [expected_out])

    # Cases where Alloc can be fully pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, y), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(1, 5))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x.squeeze(0), y), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(7, 5))
    y = tensor("y", shape=(7, 5))
    out = vector_add(x, alloc(y, 3, 7, 5))
    expected_out = alloc(vector_add(x, y), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(7, 1, 5))
    out = vector_add(x, alloc(y, 7, 2, 5))
    expected_out = alloc(vector_add(x, y), 7, 2, 5)
    assert equal([rewrite(out)], [expected_out])

    # Case where Alloc can be partially pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, alloc(y, 5)), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(7, 1, 1))
    out = vector_add(x, alloc(y, 7, 2, 5))
    expected_out = alloc(vector_add(x, alloc(y, 7, 1, 5)), 7, 2, 5)
    # NOTE: `equal` already binds strict_dtype=False; the original passed it
    # again redundantly here.
    assert equal([rewrite(out)], [expected_out])

    # Cases involving multiple Allocs being pushed
    x = tensor("x", shape=())
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, alloc(y, 5)), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    # Case where Alloc cannot be pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(1,))
    out = vector_add(x, alloc(y, 5))
    expected_out = out
    assert equal([rewrite(out)], [expected_out])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论