Unverified commit a0225e1f, authored by Carlos Trujillo, committed by GitHub

Fix Blockwise vmap dispatch for no batch dimensions (#1705)

* Fix Blockwise vmap dispatch for no batch dimensions Updates comments in funcify_Blockwise to avoid confusion about behaviour. Adds tests to verify correct behavior for these cases. * pre-commit * Take our docs
Parent 1f9a67bc
...@@ -6,30 +6,37 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -6,30 +6,37 @@ from pytensor.tensor.blockwise import Blockwise
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Convert a PyTensor ``Blockwise`` op into an MLX-callable function.

    Returns the core function directly when no vectorization is needed,
    and otherwise wraps it with ``mx.vmap`` over the batched inputs.

    Parameters
    ----------
    op : Blockwise
        The Blockwise op being dispatched.
    node
        The apply node whose inputs determine batch dimensions.

    Returns
    -------
    callable
        Either the core MLX function or an ``mx.vmap``-wrapped version of it.
    """
    # Get the core python function for this Blockwise operation
    core_node = op._create_dummy_core_node(node.inputs)
    core_f = mlx_funcify(op.core_op, core_node)

    # Determine how many batch dimensions are present in the output
    n_batch = op.batch_ndim(node)

    # If there are no batch dimensions, just return the core function
    if n_batch == 0:
        return core_f

    # Build in_axes specification for mx.vmap.
    # Each input can be vectorized (axis=0) or static (axis=None).
    in_axes: list[int | None] = []
    for inp, sig in zip(node.inputs, op.inputs_sig):
        batch_ndim = inp.type.ndim - len(sig)
        if batch_ndim == 0:
            # Input has no batch dimensions - treat as static
            in_axes.append(None)
            continue

        batch_bcast = inp.type.broadcastable[:batch_ndim]
        # If all batch dims are broadcastable (size 1), treat input as static.
        # Otherwise, vectorize over the first dimension (axis=0).
        in_axes.append(0 if not all(batch_bcast) else None)

    # If all inputs are static (no actual vectorization needed), return core function.
    # This prevents calling mx.vmap with all-None in_axes, which would raise:
    # "ValueError: At least one of in_axes must be non-None"
    if not any(axis == 0 for axis in in_axes):
        return core_f

    # Apply mx.vmap to vectorize the core function over batch dimensions
    return mx.vmap(core_f, in_axes=tuple(in_axes))
...@@ -25,3 +25,45 @@ def test_blockwise_conv1d(): ...@@ -25,3 +25,45 @@ def test_blockwise_conv1d():
# assert isinstance(out.owner.op, Blockwise) # assert isinstance(out.owner.op, Blockwise)
compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True) compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)
def test_blockwise_no_batch_dimensions():
    """Test that Blockwise returns the core function when there are no batch dimensions.

    This verifies the fix for the vmap dispatcher issue where mx.vmap should not
    be called when there are no batch dimensions to vectorize over.
    """
    rng = np.random.default_rng(42)

    # Create a blockwise matmul with no batch dimensions (core operation only)
    x = pt.matrix("x")
    y = pt.matrix("y")
    blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
    z = blockwise_matmul(x, y)

    x_test = rng.normal(size=(2, 3))
    y_test = rng.normal(size=(3, 4))
    compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)
def test_blockwise_all_broadcastable_batch_dims():
    """Test that Blockwise returns the core function when all batch dims are broadcastable.

    When all batch dimensions are size-1 (broadcastable), vmap should not be called
    since there's no actual vectorization needed.
    """
    rng = np.random.default_rng(43)

    # Create inputs with size-1 (broadcastable) batch dimensions
    x = tensor("x", shape=(1, 2, 3))
    y = tensor("y", shape=(1, 3, 4))
    blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
    z = blockwise_matmul(x, y)

    x_test = rng.normal(size=(1, 2, 3))
    y_test = rng.normal(size=(1, 3, 4))
    compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)
Markdown is supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment