提交 4b9163bc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle vectorization of Alloc nodes

上级 ba5336f6
......@@ -1802,6 +1802,40 @@ def _get_vector_length_Alloc(var_inst, var):
raise ValueError(f"Length of {var} cannot be determined")
@_vectorize_node.register(Alloc)
def vectorize_alloc(op, node, val, *shape):
    """Vectorize an Alloc node directly as a (batched) Alloc when possible.

    Falls back to the generic Blockwise wrapper whenever a new shape entry
    is not fully broadcastable, since that could describe a ragged Alloc.
    """
    old_val, *old_shape = node.inputs
    [old_alloc] = node.outputs
    assert len(shape) == len(old_shape), (
        "Number of shape entries can't change in vectorize_alloc"
    )

    for entry in shape:
        if not all(entry.broadcastable):
            # May imply a non-square Alloc
            return vectorize_node_fallback(op, node, val, *shape)

    batch_ndim_of_val = val.ndim - old_val.ndim
    batch_ndim_of_shape = max((entry.ndim for entry in shape), default=0)

    # Add implicit core dims that alloc prepends (alloc aligns val to the right)
    missing_core_dims = len(old_shape) - old_val.ndim
    if missing_core_dims > 0:
        new_axes = list(
            range(batch_ndim_of_val, batch_ndim_of_val + missing_core_dims)
        )
        val = expand_dims(val, new_axes)

    squeezed_shape = [entry.squeeze() for entry in shape]
    new_alloc = alloc(val, *val.shape[:batch_ndim_of_val], *squeezed_shape)

    # Expand leading batch dims implied by the shape entries (if any)
    return [atleast_Nd(new_alloc, n=batch_ndim_of_shape + old_alloc.ndim)]
def full(shape, fill_value, dtype=None):
"""Return a new array of given shape and type, filled with `fill_value`.
......
......@@ -5,7 +5,14 @@ from pytensor.graph.replace import vectorize_graph
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.graph.traversal import apply_ancestors
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
ARange,
alloc,
expand_dims,
shape_padleft,
)
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import (
......@@ -89,6 +96,7 @@ optdb.register(
blockwise_of(
Dot
| Alloc
| AllocEmpty
| ARange
| Subtensor
| AdvancedSubtensor
......@@ -106,7 +114,7 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
def local_blockwise_alloc_inputs(fgraph, node):
"""Push Allocs from the inputs to the output of Blockwise Ops.
BOp = Blockwise(Op, signature="(x),(x)->(x)")
......@@ -218,6 +226,34 @@ def local_blockwise_alloc(fgraph, node):
return new_outs
@register_canonicalize
@register_specialize
@node_rewriter([blockwise_of(Alloc)])
def local_blockwise_alloc(fgraph, node):
    """Replace a Blockwise(Alloc) by an equivalent plain Alloc.

    Only applies when every shape input is fully broadcastable; otherwise
    the batched shapes could describe a ragged (non-square) Alloc and the
    rewrite would be invalid.
    """
    val, *shape = node.inputs
    if any(not all(entry.broadcastable) for entry in shape):
        # May imply a non-square Alloc
        return None

    batch_ndim = node.op.batch_ndim(node)

    # Add implicit core dims that alloc prepends (alloc aligns val to the right)
    missing_core_dims = node.outputs[0].ndim - val.ndim
    if missing_core_dims > 0:
        new_axes = list(range(batch_ndim, batch_ndim + missing_core_dims))
        val = expand_dims(val, new_axes)

    squeezed_shape = [entry.squeeze() for entry in shape]
    new_out = alloc(val, *val.shape[:batch_ndim], *squeezed_shape)
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
@register_specialize
@node_rewriter([blockwise_of(Reshape)])
def local_blockwise_reshape(fgraph, node):
......
from functools import partial
import numpy as np
import pytest
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.graph.traversal import apply_ancestors
from pytensor.scalar import log as scalar_log
from pytensor.tensor import add, alloc, matrix, tensor, tensor3
from pytensor.tensor import add, alloc, iscalar, matrix, scalar, tensor, tensor3
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv
......@@ -46,7 +48,7 @@ def test_useless_unbatched_blockwise():
assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
def test_blockwise_alloc():
def test_local_blockwise_alloc_inputs():
rewrite = partial(
rewrite_graph,
include=("ShapeOpt", "specialize"),
......@@ -126,6 +128,42 @@ def test_blockwise_alloc():
assert equal([rewrite(out)], [expected_out])
@pytest.mark.parametrize("implicit_dims", [True, False])
def test_local_blockwise_alloc(implicit_dims):
    """Test that Blockwise(Alloc) is rewritten to a plain Alloc."""
    x = scalar("x")
    n = iscalar("n")
    out = alloc(x, n) if implicit_dims else alloc(x[None], n)

    # Vectorize with a batch shape that is itself an Alloc.
    # This creates Blockwise(Alloc) because the shape is non-broadcastable.
    # Other rewrites lift the Alloc above the Blockwise, then
    # local_blockwise_alloc simplifies the remaining Blockwise(Alloc).
    vect_x = tensor("vect_x", shape=(5,))
    vect_out = vectorize_graph(out, {x: vect_x, n: alloc(n, 5)})
    assert isinstance(vect_out.owner.op, Blockwise)

    rewritten = rewrite_graph(
        vect_out, include=("canonicalize", "specialize"), clone=True
    )
    assert not any(
        isinstance(apply.op, Blockwise) for apply in apply_ancestors([rewritten])
    )

    n_val = np.int64(3)
    x_val = np.random.normal(size=(5,)).astype(config.floatX)
    # Compare the unrewritten graph (run without any optimization) against
    # the rewritten one to confirm the rewrite preserves numerical results.
    no_rewrites = Mode(linker="py", optimizer=None)
    np.testing.assert_allclose(
        vect_out.eval({"vect_x": x_val, "n": n_val}, mode=no_rewrites),
        rewritten.eval({"vect_x": x_val, "n": n_val}, on_unused_input="ignore"),
    )
def test_blockwise_reshape():
x = tensor("x", shape=(None, None, None))
y = x.reshape([x.shape[0] * x.shape[1], -1])
......
......@@ -17,7 +17,8 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, vectorize_graph
from pytensor.graph.traversal import apply_ancestors
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as
......@@ -4577,6 +4578,25 @@ def test_vectorize_join(axis, broadcasting_y):
)
@pytest.mark.parametrize("implicit_dims", [True, False])
def test_vectorize_alloc(implicit_dims):
    """Vectorizing an Alloc graph should not introduce any Blockwise node."""
    x = scalar("x")
    core_out = alloc(x, 3, 5) if implicit_dims else alloc(x[None, None], 3, 5)

    vect_x = tensor("vect_x", shape=(7,))
    batched_out = vectorize_graph(core_out, {x: vect_x})
    assert not any(
        isinstance(apply.op, Blockwise) for apply in apply_ancestors([batched_out])
    )

    x_val = np.random.normal(size=(7,)).astype(config.floatX)
    # Batched alloc broadcasts each batch entry over the (3, 5) core shape.
    expected = np.broadcast_to(x_val[:, None, None], (7, 3, 5))
    np.testing.assert_allclose(batched_out.eval({vect_x: x_val}), expected)
def test_where():
a = np.arange(10)
cond = a < 5
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论