提交 fe06ee32 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize Subtensor without batched indices

上级 2c41735a
...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config ...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
...@@ -22,6 +23,7 @@ from pytensor.printing import Printer, pprint, set_precedence ...@@ -22,6 +23,7 @@ from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip from pytensor.tensor.math import clip
...@@ -1283,6 +1285,21 @@ class SubtensorPrinter(Printer): ...@@ -1283,6 +1285,21 @@ class SubtensorPrinter(Printer):
pprint.assign(Subtensor, SubtensorPrinter()) pprint.assign(Subtensor, SubtensorPrinter())
# TODO: Implement similar vectorize for Inc/SetSubtensor
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
    """Vectorize a Subtensor whose index inputs carry no batch dimensions.

    When only ``x`` gained batch dimensions, the core indexing can be applied
    unchanged to every batch entry by prepending one empty slice per batch
    dimension; otherwise we fall back to the generic Blockwise wrapper.
    """
    # TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
    has_batched_idx = any(idx.type.ndim > 0 for idx in batch_idxs)
    if has_batched_idx:
        # Some index input gained batch dimensions: let the fallback build a Blockwise.
        return vectorize_node_fallback(op, node, batch_x, *batch_idxs)

    core_x = node.inputs[0]
    n_batch_dims = batch_x.type.ndim - core_x.type.ndim
    leading_slices = (slice(None),) * n_batch_dims
    return Subtensor(leading_slices + op.idx_list).make_node(batch_x, *batch_idxs)
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False): def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
""" """
Return x with the given subtensor overwritten by y. Return x with the given subtensor overwritten by y.
......
...@@ -9,6 +9,7 @@ from numpy.testing import assert_array_equal ...@@ -9,6 +9,7 @@ from numpy.testing import assert_array_equal
import pytensor import pytensor
import pytensor.scalar as scal import pytensor.scalar as scal
import pytensor.tensor.basic as at import pytensor.tensor.basic as at
from pytensor import function
from pytensor.compile import DeepCopyOp, shared from pytensor.compile import DeepCopyOp, shared
from pytensor.compile.io import In from pytensor.compile.io import In
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -16,7 +17,8 @@ from pytensor.graph.op import get_test_value ...@@ -16,7 +17,8 @@ from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar from pytensor.scalar.basic import as_scalar
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf from pytensor.tensor.math import exp, isinf
from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import sum as at_sum
...@@ -2709,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected): ...@@ -2709,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected):
x = at.tensor(dtype="float64", shape=x_shape) x = at.tensor(dtype="float64", shape=x_shape)
y = x[indices] y = x[indices]
assert y.type.shape == expected assert y.type.shape == expected
def test_vectorize_subtensor_without_batch_indices():
    """Vectorizing x[:, start, :] over batched x must avoid Blockwise; batching start must use it."""
    signature = "(t1,t2,t3),()->(t1,t3)"

    def core_fn(x, start):
        return x[:, start, :]

    # Case 1: only `x` is batched. The graph should be a plain Subtensor with
    # prepended empty slices, so no Blockwise node may appear.
    x = tensor(shape=(11, 7, 5, 3))
    start = tensor(shape=(), dtype="int")
    vectorize_pt = function(
        [x, start], vectorize(core_fn, signature=signature)(x, start)
    )
    fgraph_nodes = vectorize_pt.maker.fgraph.apply_nodes
    assert all(not isinstance(node.op, Blockwise) for node in fgraph_nodes)

    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
    start_test = np.random.randint(0, x.type.shape[-2])
    vectorize_np = np.vectorize(core_fn, signature=signature)
    np.testing.assert_allclose(
        vectorize_pt(x_test, start_test),
        vectorize_np(x_test, start_test),
    )

    # Case 2: `start` is batched too. A Blockwise is expected, and the result
    # must still agree with numpy's reference vectorization.
    x = tensor(shape=(11, 7, 5, 3))
    start = tensor(shape=(11,), dtype="int")
    vectorize_pt = function(
        [x, start], vectorize(core_fn, signature=signature)(x, start)
    )
    fgraph_nodes = vectorize_pt.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Blockwise) for node in fgraph_nodes)

    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
    start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0])
    vectorize_np = np.vectorize(core_fn, signature=signature)
    np.testing.assert_allclose(
        vectorize_pt(x_test, start_test),
        vectorize_np(x_test, start_test),
    )
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论