提交 fe06ee32 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize Subtensor without batched indices

上级 2c41735a
......@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp
......@@ -22,6 +23,7 @@ from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
......@@ -1283,6 +1285,21 @@ class SubtensorPrinter(Printer):
pprint.assign(Subtensor, SubtensorPrinter())
# TODO: Implement similar vectorize for Inc/SetSubtensor
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs):
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
old_x, *_ = node.inputs
batch_ndims = batch_x.type.ndim - old_x.type.ndim
new_idx_list = (slice(None),) * batch_ndims + op.idx_list
return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs)
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
"""
Return x with the given subtensor overwritten by y.
......
......@@ -9,6 +9,7 @@ from numpy.testing import assert_array_equal
import pytensor
import pytensor.scalar as scal
import pytensor.tensor.basic as at
from pytensor import function
from pytensor.compile import DeepCopyOp, shared
from pytensor.compile.io import In
from pytensor.configdefaults import config
......@@ -16,7 +17,8 @@ from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar
from pytensor.tensor import get_vector_length
from pytensor.tensor import get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf
from pytensor.tensor.math import sum as at_sum
......@@ -2709,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected):
x = at.tensor(dtype="float64", shape=x_shape)
y = x[indices]
assert y.type.shape == expected
def test_vectorize_subtensor_without_batch_indices():
signature = "(t1,t2,t3),()->(t1,t3)"
def core_fn(x, start):
return x[:, start, :]
x = tensor(shape=(11, 7, 5, 3))
start = tensor(shape=(), dtype="int")
vectorize_pt = function(
[x, start], vectorize(core_fn, signature=signature)(x, start)
)
assert not any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
)
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
start_test = np.random.randint(0, x.type.shape[-2])
vectorize_np = np.vectorize(core_fn, signature=signature)
np.testing.assert_allclose(
vectorize_pt(x_test, start_test),
vectorize_np(x_test, start_test),
)
# If we vectorize start, we should get a Blockwise that still works
x = tensor(shape=(11, 7, 5, 3))
start = tensor(shape=(11,), dtype="int")
vectorize_pt = function(
[x, start], vectorize(core_fn, signature=signature)(x, start)
)
assert any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
)
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0])
vectorize_np = np.vectorize(core_fn, signature=signature)
np.testing.assert_allclose(
vectorize_pt(x_test, start_test),
vectorize_np(x_test, start_test),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论