提交 a62e785d 作者: ricardoV94 提交者: Ricardo Vieira

Allow Blockwise to create dummy core nodes with outer inputs, if these are unbatched

上级 efc9d693
import jax.numpy as jnp
from pytensor.graph import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise
@jax_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
def jax_funcify_Blockwise(op: Blockwise, node, **kwargs):
signature = op.signature
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
tuple_core_fn = jax_funcify(core_fgraph)
if len(node.outputs) == 1:
def core_fn(*inputs):
return tuple_core_fn(*inputs)[0]
else:
core_fn = tuple_core_fn
core_node = op._create_dummy_core_node(
node.inputs, propagate_unbatched_core_inputs=True
)
core_fn = jax_funcify(core_node.op, node=core_node, **kwargs)
vect_fn = jnp.vectorize(core_fn, signature=signature)
......
......@@ -16,7 +16,7 @@ from pytensor.tensor import TensorVariable, get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
@numba_funcify.register
@numba_funcify.register(BlockwiseWithCoreShape)
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
[blockwise_node] = op.fgraph.apply_nodes
blockwise_op: Blockwise = blockwise_node.op
......@@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
core_node = blockwise_op._create_dummy_core_node(
cast(tuple[TensorVariable], blockwise_node.inputs)
cast(tuple[TensorVariable], node.inputs[:nin]),
propagate_unbatched_core_inputs=True,
)
core_op_fn = numba_funcify(
core_op,
......
from collections.abc import Callable, Sequence
from typing import Any, cast
from typing import Any, Literal, cast, overload
import numpy as np
from numpy import broadcast_shapes, empty
......@@ -32,6 +32,17 @@ from pytensor.tensor.utils import (
from pytensor.tensor.variable import TensorVariable
def _squeeze_left(x, stop_at_dim: int | None = None):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims = x.type.broadcastable
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
if stop_at_dim is not None:
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
if squeeze_ndim == 0:
return x
return x.squeeze(axis=tuple(range(squeeze_ndim)))
def _vectorize_node_perform(
core_node: Apply,
batch_bcast_patterns: Sequence[tuple[bool, ...]],
......@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
class Blockwise(COp):
"""Generalizes a core `Op` to work with batched dimensions.
TODO: Dispatch JAX (should be easy with the vectorize macro)
TODO: Dispatch Numba
TODO: C implementation?
TODO: Fuse Blockwise?
"""
......@@ -202,21 +211,52 @@ class Blockwise(COp):
super().__init__(**kwargs)
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
core_input_types = []
@overload
def _create_dummy_core_node(
self,
inputs: Sequence[TensorVariable],
*,
propagate_unbatched_core_inputs: bool = False,
return_dummy_inputs: Literal[False] = ...,
) -> Apply: ...
@overload
def _create_dummy_core_node(
self,
inputs: Sequence[TensorVariable],
*,
propagate_unbatched_core_inputs: bool = False,
return_dummy_inputs: Literal[True] = ...,
) -> tuple[Apply, list[TensorVariable]]: ...
def _create_dummy_core_node(
self,
inputs: Sequence[TensorVariable],
*,
propagate_unbatched_core_inputs: bool = False,
return_dummy_inputs: bool = False,
) -> Apply | tuple[Apply, list[TensorVariable]]:
core_inputs = []
core_dummy_inputs = []
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
if inp.type.ndim < len(sig):
raise ValueError(
f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
)
# ndim_supp = 0 case
if not sig:
core_shape = ()
inp_ndim = inp.type.ndim
batch_ndim = inp_ndim - len(sig)
core_shape = inp.type.shape[batch_ndim:]
if propagate_unbatched_core_inputs and all(
inp.type.broadcastable[:batch_ndim]
):
core_inputs.append(_squeeze_left(inp, batch_ndim))
else:
core_shape = inp.type.shape[-len(sig) :]
core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape))
dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape)
core_inputs.append(dummy_inp)
core_dummy_inputs.append(dummy_inp)
core_node = self.core_op.make_node(*core_input_types)
core_node = self.core_op.make_node(*core_inputs)
if len(core_node.outputs) != len(self.outputs_sig):
raise ValueError(
......@@ -230,6 +270,9 @@ class Blockwise(COp):
f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
)
if return_dummy_inputs:
return core_node, core_dummy_inputs
return core_node
def make_node(self, *inputs):
......@@ -298,11 +341,17 @@ class Blockwise(COp):
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
# Try to extract the core shapes from the core_op
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
if core_op_infer_shape is not None:
dummy_core_node = self._create_dummy_core_node(node.inputs)
dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
def extract_core_shape_from_infer_shape():
# Try to extract the core shapes from the core_op
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
if core_op_infer_shape is None:
return [[None] * out.ndim for out in node.outputs]
dummy_core_node, dummy_core_inputs = self._create_dummy_core_node(
node.inputs,
return_dummy_inputs=True,
propagate_unbatched_core_inputs=True,
)
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
core_input_shapes = [
input_shape[batch_ndims:] for input_shape in input_shapes
......@@ -311,6 +360,25 @@ class Blockwise(COp):
dummy_fgraph, dummy_core_node, core_input_shapes
)
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
if not dummy_core_inputs:
# All inputs are unbatched, so the core_shape can be used as is
return core_output_shapes
else:
set_dummy_core_inputs = set(dummy_core_inputs)
safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
for core_out_shape in safe_core_output_shapes:
for o, core_out_dim in enumerate(core_out_shape):
if set_dummy_core_inputs & set(
explicit_graph_inputs([core_out_dim])
):
core_out_shape[o] = None
return safe_core_output_shapes
safe_core_out_shape = None
out_shapes = []
for o, (output, sig) in enumerate(
zip(node.outputs, self.outputs_sig, strict=True)
......@@ -321,19 +389,15 @@ class Blockwise(COp):
if dim_name in core_dims:
core_out_shape.append(core_dims[dim_name])
else:
if core_op_infer_shape is not None:
# If the input values are needed to compute the dimension length, we can't use the infer_shape
# of the core_node as the value is not constant across batch dims of the Blockwise
core_out_dim = core_output_shapes[o][i]
if not (
set(dummy_core_inputs)
& set(explicit_graph_inputs([core_out_dim]))
):
core_out_shape.append(core_out_dim)
continue
# Fallback shape requires evaluating the Blockwise Op
core_out_shape.append(Shape_i(batch_ndims + i)(output))
if safe_core_out_shape is None:
# Extract the core shape from the core_op infer_shape on demand
# For many Ops we never need to do this, because all info is in their signature
safe_core_out_shape = extract_core_shape_from_infer_shape()
if (core_out_dim := safe_core_out_shape[o][i]) is not None:
core_out_shape.append(core_out_dim)
else:
# Fallback shape requires evaluating the Blockwise Op
core_out_shape.append(Shape_i(batch_ndims + i)(output))
out_shapes.append((*batch_shape, *core_out_shape))
return out_shapes
......@@ -448,7 +512,10 @@ class Blockwise(COp):
)
return core_func(*inputs)
else:
core_node = self._create_dummy_core_node(node.inputs) # type: ignore
core_node = self._create_dummy_core_node(
cast(list[TensorVariable], node.inputs),
propagate_unbatched_core_inputs=True,
)
gufunc = _vectorize_node_perform(
core_node,
batch_bcast_patterns=batch_bcast_patterns,
......
......@@ -4,7 +4,7 @@ from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
......@@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
return local_useless_unbatched_blockwise.fn(fgraph, node)
def _squeeze_left(x, stop_at_dim: int | None = None):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims = x.type.broadcastable
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
if stop_at_dim is not None:
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
if squeeze_ndim == 0:
return x
return x.squeeze(axis=tuple(range(squeeze_ndim)))
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论