Commit 301f10dc, authored by Ricardo Vieira, committed by Ricardo Vieira

Add rewrite to remove Blockwise of AdvancedIncSubtensor

Parent: 2e2c871e
...@@ -29,6 +29,7 @@ from pytensor.tensor.basic import ( ...@@ -29,6 +29,7 @@ from pytensor.tensor.basic import (
register_infer_shape, register_infer_shape,
switch, switch,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, add from pytensor.tensor.math import Dot, add
...@@ -1880,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1880,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node):
copy_stack_trace(node.outputs, new_outs) copy_stack_trace(node.outputs, new_outs)
return new_outs return new_outs
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_advanced_inc_subtensor(fgraph, node):
    """Rewrite a Blockwise of AdvancedIncSubtensor without batched indices
    as a plain AdvancedIncSubtensor with empty slices prepended for the
    batch dimensions.

    Returns the new outputs when the rewrite applies, otherwise ``None``
    (indices that are batched, slices/None, or non-scalar booleans keep the
    Blockwise untouched).
    """
    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
        return None

    x, y, *idxs = node.inputs

    # It is currently not possible to vectorize such an AdvancedIncSubtensor,
    # but we check again just in case. NOTE: the original checked
    # `isinstance(idx, ...)` against Type classes, which is always False for
    # a Variable; the type of the variable is what must be inspected.
    if any(
        isinstance(idx.type, (SliceType, NoneTypeT))
        or (idx.type.dtype == "bool" and idx.type.ndim > 0)
        for idx in idxs
    ):
        return None

    op: Blockwise = node.op  # type: ignore
    batch_ndim = op.batch_ndim(node)

    new_idxs = []
    for idx in idxs:
        if all(idx.type.broadcastable[:batch_ndim]):
            # Index is constant across batch dims: squeeze out the
            # broadcastable batch dimensions so it acts as a core index.
            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
        else:
            # Index is genuinely batched: rewrite does not apply.
            return None

    x_batch_bcast = x.type.broadcastable[:batch_ndim]
    y_batch_bcast = y.type.broadcastable[:batch_ndim]
    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)):
        # x is broadcastable along batch dims where y is not: materialize
        # the broadcast with alloc so the inc_subtensor writes into a
        # buffer of the full batched shape.
        batch_shape = tuple(
            x_dim if (not xb or yb) else y_dim
            for xb, x_dim, yb, y_dim in zip(
                x_batch_bcast,
                tuple(x.shape)[:batch_ndim],
                y_batch_bcast,
                tuple(y.shape)[:batch_ndim],
            )
        )
        core_shape = tuple(x.shape)[batch_ndim:]
        x = alloc(x, *batch_shape, *core_shape)

    # Prepend one full slice per batch dimension, then let the normal
    # __getitem__ machinery build the symbolic index inputs for us.
    new_idxs = [slice(None)] * batch_ndim + new_idxs
    symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
    copy_stack_trace(node.outputs, new_out)
    return new_out
...@@ -9,7 +9,7 @@ from pytensor.compile.function import function ...@@ -9,7 +9,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -18,6 +18,7 @@ from pytensor.graph.type import Type ...@@ -18,6 +18,7 @@ from pytensor.graph.type import Type
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor import inplace from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, add, dot, exp, sqr from pytensor.tensor.math import Dot, add, dot, exp, sqr
from pytensor.tensor.rewriting.subtensor import ( from pytensor.tensor.rewriting.subtensor import (
...@@ -2314,3 +2315,98 @@ def test_local_uint_constant_indices(): ...@@ -2314,3 +2315,98 @@ def test_local_uint_constant_indices():
new_index = subtensor_node.inputs[1] new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
@pytest.mark.parametrize("set_instead_of_inc", (True, False))
def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
    """Blockwise(AdvancedIncSubtensor) with unbatched indices is rewritten
    away: the compiled graph contains no Blockwise and computes the same
    result as NumPy fancy indexing over the batch dimensions.
    """
    core_x = tensor("x", shape=(6,))
    core_y = tensor("y", shape=(3,))
    core_idxs = [0, 2, 4]
    if set_instead_of_inc:
        core_graph = set_subtensor(core_x[core_idxs], core_y)
    else:
        core_graph = inc_subtensor(core_x[core_idxs], core_y)

    # Only x is batched
    x = tensor("x", shape=(5, 2, 6))
    y = tensor("y", shape=(3,))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
    expected_out = test_x.copy()
    if set_instead_of_inc:
        expected_out[:, :, core_idxs] = test_y
    else:
        expected_out[:, :, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Only y is batched
    # NOTE: the x tensors below were mislabeled "y" in the original test;
    # the debug names now match the variables they belong to.
    x = tensor("x", shape=(6,))
    y = tensor("y", shape=(2, 3))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
    expected_out = np.ones((2, *x.type.shape))
    if set_instead_of_inc:
        expected_out[:, core_idxs] = test_y
    else:
        expected_out[:, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Both x and y are batched, and do not need to be broadcasted
    x = tensor("x", shape=(2, 6))
    y = tensor("y", shape=(2, 3))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
    expected_out = test_x.copy()
    if set_instead_of_inc:
        expected_out[:, core_idxs] = test_y
    else:
        expected_out[:, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Both x and y are batched, but must be broadcasted
    x = tensor("x", shape=(5, 1, 6))
    y = tensor("y", shape=(1, 2, 3))
    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out, mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
    # The output broadcasts the batch dims of x and y: (5, 1) x (1, 2) -> (5, 2)
    final_shape = (
        *np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
        x.type.shape[-1],
    )
    expected_out = np.broadcast_to(test_x, final_shape).copy()
    if set_instead_of_inc:
        expected_out[:, :, core_idxs] = test_y
    else:
        expected_out[:, :, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment