提交 301f10dc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add rewrite to remove Blockwise of AdvancedIncSubtensor

上级 2e2c871e
......@@ -29,6 +29,7 @@ from pytensor.tensor.basic import (
register_infer_shape,
switch,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, add
......@@ -1880,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node):
copy_stack_trace(node.outputs, new_outs)
return new_outs
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_advanced_inc_subtensor(fgraph, node):
    """Rewrite a Blockwise of AdvancedIncSubtensor without batched indexes
    as a core AdvancedIncSubtensor with empty slices prepended for the
    batch dimensions.

    Only applies when every index is broadcastable along all batch
    dimensions (i.e. the indexes are effectively unbatched); otherwise the
    rewrite returns None and the Blockwise is left in place.

    Returns
    -------
    list of Variable or None
        The outputs of the replacement AdvancedIncSubtensor node, or None
        when the rewrite does not apply.
    """
    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
        return None

    x, y, *idxs = node.inputs

    # It is currently not possible to Vectorize such AdvancedIncSubtensor,
    # but we check again just in case.
    # NOTE: the check must be on `idx.type` — `idx` is a Variable, so
    # `isinstance(idx, SliceType)` would always be False and the subsequent
    # `idx.type.dtype` access would raise for slice/none-typed inputs.
    if any(
        (
            isinstance(idx.type, (SliceType, NoneTypeT))
            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
        )
        for idx in idxs
    ):
        return None

    op: Blockwise = node.op  # type: ignore
    batch_ndim = op.batch_ndim(node)

    # Squeeze out the (broadcastable) batch dimensions of each index so it
    # can be used by the core op directly.
    new_idxs = []
    for idx in idxs:
        if all(idx.type.broadcastable[:batch_ndim]):
            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
        else:
            # Index is genuinely batched: rewrite does not apply
            return None

    x_batch_bcast = x.type.broadcastable[:batch_ndim]
    y_batch_bcast = y.type.broadcastable[:batch_ndim]
    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)):
        # Need to broadcast batch x dims so the buffer being updated has
        # the full batch shape implied by y.
        batch_shape = tuple(
            x_dim if (not xb or yb) else y_dim
            for xb, x_dim, yb, y_dim in zip(
                x_batch_bcast,
                tuple(x.shape)[:batch_ndim],
                y_batch_bcast,
                tuple(y.shape)[:batch_ndim],
            )
        )
        core_shape = tuple(x.shape)[batch_ndim:]
        x = alloc(x, *batch_shape, *core_shape)

    # Prepend one full slice per batch dimension, then let __getitem__ build
    # the canonical symbolic index inputs for the core op.
    new_idxs = [slice(None)] * batch_ndim + new_idxs
    symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
    copy_stack_trace(node.outputs, new_out)
    return new_out
......@@ -9,7 +9,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph import FunctionGraph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -18,6 +18,7 @@ from pytensor.graph.type import Type
from pytensor.raise_op import Assert
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, add, dot, exp, sqr
from pytensor.tensor.rewriting.subtensor import (
......@@ -2314,3 +2315,98 @@ def test_local_uint_constant_indices():
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
@pytest.mark.parametrize("set_instead_of_inc", (True, False))
def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
    """Check that Blockwise of AdvancedIncSubtensor without batched indexes
    is rewritten away (no Blockwise left in the compiled graph) and that the
    compiled function matches the NumPy reference in all batching cases."""
    core_x = tensor("x", shape=(6,))
    core_y = tensor("y", shape=(3,))
    core_idxs = [0, 2, 4]
    if set_instead_of_inc:
        core_graph = set_subtensor(core_x[core_idxs], core_y)
    else:
        core_graph = inc_subtensor(core_x[core_idxs], core_y)

    # Only x is batched
    x = tensor("x", shape=(5, 2, 6))
    y = tensor("y", shape=(3,))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
    expected_out = test_x.copy()
    if set_instead_of_inc:
        expected_out[:, :, core_idxs] = test_y
    else:
        expected_out[:, :, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Only y is batched
    x = tensor("x", shape=(6,))
    y = tensor("y", shape=(2, 3))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
    expected_out = np.ones((2, *x.type.shape))
    if set_instead_of_inc:
        expected_out[:, core_idxs] = test_y
    else:
        expected_out[:, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Both x and y are batched, and do not need to be broadcasted
    x = tensor("x", shape=(2, 6))
    y = tensor("y", shape=(2, 3))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
    expected_out = test_x.copy()
    if set_instead_of_inc:
        expected_out[:, core_idxs] = test_y
    else:
        expected_out[:, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Both x and y are batched, but must be broadcasted
    x = tensor("x", shape=(5, 1, 6))
    y = tensor("y", shape=(1, 2, 3))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
    # Batch dims of x and y broadcast against each other; the core (last)
    # dim comes from x.
    final_shape = (
        *np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
        x.type.shape[-1],
    )
    expected_out = np.broadcast_to(test_x, final_shape).copy()
    if set_instead_of_inc:
        expected_out[:, :, core_idxs] = test_y
    else:
        expected_out[:, :, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论