提交 a62e785d 作者: ricardoV94 提交者: Ricardo Vieira

Allow Blockwise to create dummy core nodes with outer inputs, if these are unbatched

上级 efc9d693
import jax.numpy as jnp import jax.numpy as jnp
from pytensor.graph import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
@jax_funcify.register(Blockwise) @jax_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): def jax_funcify_Blockwise(op: Blockwise, node, **kwargs):
signature = op.signature signature = op.signature
core_node = op._create_dummy_core_node(node.inputs) core_node = op._create_dummy_core_node(
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) node.inputs, propagate_unbatched_core_inputs=True
tuple_core_fn = jax_funcify(core_fgraph) )
core_fn = jax_funcify(core_node.op, node=core_node, **kwargs)
if len(node.outputs) == 1:
def core_fn(*inputs):
return tuple_core_fn(*inputs)[0]
else:
core_fn = tuple_core_fn
vect_fn = jnp.vectorize(core_fn, signature=signature) vect_fn = jnp.vectorize(core_fn, signature=signature)
......
...@@ -16,7 +16,7 @@ from pytensor.tensor import TensorVariable, get_vector_length ...@@ -16,7 +16,7 @@ from pytensor.tensor import TensorVariable, get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
@numba_funcify.register @numba_funcify.register(BlockwiseWithCoreShape)
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
[blockwise_node] = op.fgraph.apply_nodes [blockwise_node] = op.fgraph.apply_nodes
blockwise_op: Blockwise = blockwise_node.op blockwise_op: Blockwise = blockwise_node.op
...@@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): ...@@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:]) core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
core_node = blockwise_op._create_dummy_core_node( core_node = blockwise_op._create_dummy_core_node(
cast(tuple[TensorVariable], blockwise_node.inputs) cast(tuple[TensorVariable], node.inputs[:nin]),
propagate_unbatched_core_inputs=True,
) )
core_op_fn = numba_funcify( core_op_fn = numba_funcify(
core_op, core_op,
......
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, cast from typing import Any, Literal, cast, overload
import numpy as np import numpy as np
from numpy import broadcast_shapes, empty from numpy import broadcast_shapes, empty
...@@ -32,6 +32,17 @@ from pytensor.tensor.utils import ( ...@@ -32,6 +32,17 @@ from pytensor.tensor.utils import (
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def _squeeze_left(x, stop_at_dim: int | None = None):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims = x.type.broadcastable
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
if stop_at_dim is not None:
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
if squeeze_ndim == 0:
return x
return x.squeeze(axis=tuple(range(squeeze_ndim)))
def _vectorize_node_perform( def _vectorize_node_perform(
core_node: Apply, core_node: Apply,
batch_bcast_patterns: Sequence[tuple[bool, ...]], batch_bcast_patterns: Sequence[tuple[bool, ...]],
...@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ ...@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
class Blockwise(COp): class Blockwise(COp):
"""Generalizes a core `Op` to work with batched dimensions. """Generalizes a core `Op` to work with batched dimensions.
TODO: Dispatch JAX (should be easy with the vectorize macro)
TODO: Dispatch Numba
TODO: C implementation? TODO: C implementation?
TODO: Fuse Blockwise? TODO: Fuse Blockwise?
""" """
...@@ -202,21 +211,52 @@ class Blockwise(COp): ...@@ -202,21 +211,52 @@ class Blockwise(COp):
super().__init__(**kwargs) super().__init__(**kwargs)
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: @overload
core_input_types = [] def _create_dummy_core_node(
self,
inputs: Sequence[TensorVariable],
*,
propagate_unbatched_core_inputs: bool = False,
return_dummy_inputs: Literal[False] = ...,
) -> Apply: ...
@overload
def _create_dummy_core_node(
self,
inputs: Sequence[TensorVariable],
*,
propagate_unbatched_core_inputs: bool = False,
return_dummy_inputs: Literal[True] = ...,
) -> tuple[Apply, list[TensorVariable]]: ...
def _create_dummy_core_node(
self,
inputs: Sequence[TensorVariable],
*,
propagate_unbatched_core_inputs: bool = False,
return_dummy_inputs: bool = False,
) -> Apply | tuple[Apply, list[TensorVariable]]:
core_inputs = []
core_dummy_inputs = []
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
if inp.type.ndim < len(sig): if inp.type.ndim < len(sig):
raise ValueError( raise ValueError(
f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
) )
# ndim_supp = 0 case # ndim_supp = 0 case
if not sig: inp_ndim = inp.type.ndim
core_shape = () batch_ndim = inp_ndim - len(sig)
core_shape = inp.type.shape[batch_ndim:]
if propagate_unbatched_core_inputs and all(
inp.type.broadcastable[:batch_ndim]
):
core_inputs.append(_squeeze_left(inp, batch_ndim))
else: else:
core_shape = inp.type.shape[-len(sig) :] dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape)
core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape)) core_inputs.append(dummy_inp)
core_dummy_inputs.append(dummy_inp)
core_node = self.core_op.make_node(*core_input_types) core_node = self.core_op.make_node(*core_inputs)
if len(core_node.outputs) != len(self.outputs_sig): if len(core_node.outputs) != len(self.outputs_sig):
raise ValueError( raise ValueError(
...@@ -230,6 +270,9 @@ class Blockwise(COp): ...@@ -230,6 +270,9 @@ class Blockwise(COp):
f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
) )
if return_dummy_inputs:
return core_node, core_dummy_inputs
return core_node return core_node
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -298,11 +341,17 @@ class Blockwise(COp): ...@@ -298,11 +341,17 @@ class Blockwise(COp):
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
# Try to extract the core shapes from the core_op def extract_core_shape_from_infer_shape():
core_op_infer_shape = getattr(self.core_op, "infer_shape", None) # Try to extract the core shapes from the core_op
if core_op_infer_shape is not None: core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
dummy_core_node = self._create_dummy_core_node(node.inputs) if core_op_infer_shape is None:
dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs)) return [[None] * out.ndim for out in node.outputs]
dummy_core_node, dummy_core_inputs = self._create_dummy_core_node(
node.inputs,
return_dummy_inputs=True,
propagate_unbatched_core_inputs=True,
)
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
core_input_shapes = [ core_input_shapes = [
input_shape[batch_ndims:] for input_shape in input_shapes input_shape[batch_ndims:] for input_shape in input_shapes
...@@ -311,6 +360,25 @@ class Blockwise(COp): ...@@ -311,6 +360,25 @@ class Blockwise(COp):
dummy_fgraph, dummy_core_node, core_input_shapes dummy_fgraph, dummy_core_node, core_input_shapes
) )
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
if not dummy_core_inputs:
# All inputs are unbatched, so the core_shape can be used as is
return core_output_shapes
else:
set_dummy_core_inputs = set(dummy_core_inputs)
safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
for core_out_shape in safe_core_output_shapes:
for o, core_out_dim in enumerate(core_out_shape):
if set_dummy_core_inputs & set(
explicit_graph_inputs([core_out_dim])
):
core_out_shape[o] = None
return safe_core_output_shapes
safe_core_out_shape = None
out_shapes = [] out_shapes = []
for o, (output, sig) in enumerate( for o, (output, sig) in enumerate(
zip(node.outputs, self.outputs_sig, strict=True) zip(node.outputs, self.outputs_sig, strict=True)
...@@ -321,19 +389,15 @@ class Blockwise(COp): ...@@ -321,19 +389,15 @@ class Blockwise(COp):
if dim_name in core_dims: if dim_name in core_dims:
core_out_shape.append(core_dims[dim_name]) core_out_shape.append(core_dims[dim_name])
else: else:
if core_op_infer_shape is not None: if safe_core_out_shape is None:
# If the input values are needed to compute the dimension length, we can't use the infer_shape # Extract the core shape from the core_op infer_shape on demand
# of the core_node as the value is not constant across batch dims of the Blockwise # For many Ops we never need to do this, because all info is in their signature
core_out_dim = core_output_shapes[o][i] safe_core_out_shape = extract_core_shape_from_infer_shape()
if not ( if (core_out_dim := safe_core_out_shape[o][i]) is not None:
set(dummy_core_inputs) core_out_shape.append(core_out_dim)
& set(explicit_graph_inputs([core_out_dim])) else:
): # Fallback shape requires evaluating the Blockwise Op
core_out_shape.append(core_out_dim) core_out_shape.append(Shape_i(batch_ndims + i)(output))
continue
# Fallback shape requires evaluating the Blockwise Op
core_out_shape.append(Shape_i(batch_ndims + i)(output))
out_shapes.append((*batch_shape, *core_out_shape)) out_shapes.append((*batch_shape, *core_out_shape))
return out_shapes return out_shapes
...@@ -448,7 +512,10 @@ class Blockwise(COp): ...@@ -448,7 +512,10 @@ class Blockwise(COp):
) )
return core_func(*inputs) return core_func(*inputs)
else: else:
core_node = self._create_dummy_core_node(node.inputs) # type: ignore core_node = self._create_dummy_core_node(
cast(list[TensorVariable], node.inputs),
propagate_unbatched_core_inputs=True,
)
gufunc = _vectorize_node_perform( gufunc = _vectorize_node_perform(
core_node, core_node,
batch_bcast_patterns=batch_bcast_patterns, batch_bcast_patterns=batch_bcast_patterns,
......
...@@ -4,7 +4,7 @@ from pytensor.graph.destroyhandler import inplace_candidates ...@@ -4,7 +4,7 @@ from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
...@@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node): ...@@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
return local_useless_unbatched_blockwise.fn(fgraph, node) return local_useless_unbatched_blockwise.fn(fgraph, node)
def _squeeze_left(x, stop_at_dim: int | None = None):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims = x.type.broadcastable
squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
if stop_at_dim is not None:
squeeze_ndim = min(squeeze_ndim, stop_at_dim)
if squeeze_ndim == 0:
return x
return x.squeeze(axis=tuple(range(squeeze_ndim)))
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node): def local_blockwise_alloc(fgraph, node):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论