提交 7eafb6c6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize shape operations

上级 7f0567a8
......@@ -204,7 +204,7 @@ def graph_replace(
@singledispatch
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError
......@@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
return _vectorize_node(op, node, *batched_inputs)
def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs)
@overload
def vectorize_graph(
outputs: Variable,
......
......@@ -8,7 +8,11 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.graph.replace import (
_vectorize_node,
_vectorize_not_needed,
vectorize_graph,
)
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
......@@ -37,17 +41,6 @@ def safe_signature(
return f"{inputs_sig}->{outputs_sig}"
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *batched_inputs) -> Apply:
    """Vectorize an arbitrary `Op` by wrapping it in a `Blockwise`.

    Uses the op's ``gufunc_signature`` when it provides one; otherwise
    derives a conservative signature from the node's inputs and outputs.
    Fixes the misspelled ``*bached_inputs`` parameter name (var-positional,
    so the rename is backward compatible).
    """
    if hasattr(op, "gufunc_signature"):
        signature = op.gufunc_signature
    else:
        # TODO: This is pretty bad for shape inference and merge optimization!
        # Should get better as we add signatures to our Ops
        signature = safe_signature(node.inputs, node.outputs)
    return cast(Apply, Blockwise(op, signature=signature).make_node(*batched_inputs))
class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions.
......@@ -361,6 +354,15 @@ class Blockwise(Op):
return self.name
@_vectorize_node.register(Blockwise)
def vectorize_not_needed(op, node, *batch_inputs):
    """A `Blockwise` op already handles batch dims; just rebuild the node."""
    rebuilt = op.make_node(*batch_inputs)
    return rebuilt
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *batched_inputs) -> Apply:
    """Vectorize an arbitrary `Op` by wrapping it in a `Blockwise`.

    Uses the op's ``gufunc_signature`` when it provides one; otherwise
    derives a conservative signature from the node's inputs and outputs.
    Fixes the misspelled ``*bached_inputs`` parameter name (var-positional,
    so the rename is backward compatible).
    """
    if hasattr(op, "gufunc_signature"):
        signature = op.gufunc_signature
    else:
        # TODO: This is pretty bad for shape inference and merge optimization!
        # Should get better as we add signatures to our Ops
        signature = safe_signature(node.inputs, node.outputs)
    return cast(Apply, Blockwise(op, signature=signature).make_node(*batched_inputs))
_vectorize_node.register(Blockwise, _vectorize_not_needed)
......@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
......@@ -22,7 +22,6 @@ from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.blockwise import vectorize_not_needed
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
......@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var):
raise ValueError(f"Length of {var} cannot be determined")
_vectorize_node.register(Elemwise, vectorize_not_needed)
_vectorize_node.register(Elemwise, _vectorize_not_needed)
@_vectorize_node.register(DimShuffle)
......
......@@ -8,6 +8,7 @@ import numpy as np
import pytensor
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
......@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim
_vectorize_node.register(Shape, _vectorize_not_needed)
def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
r"""Get a tuple of symbolic shape values.
......@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var):
raise ValueError(f"Length of {var} cannot be determined")
@_vectorize_node.register(SpecifyShape)
def _vectorize_specify_shape(op, node, x, *shape):
    """Vectorize `SpecifyShape` for an input `x` with extra batch dims.

    When the same number of shape entries is given as in the original node,
    the new batch dims are left unconstrained (``None``). When the entries
    already cover the batch dims, they are used as-is.
    """
    old_x = node.inputs[0]
    old_shape = node.inputs[1:]
    batched_ndims = x.type.ndim - old_x.type.ndim

    # Only scalar (0d) shape entries can be vectorized.
    for dim in shape:
        if NoneConst.equals(dim) or dim is None:
            continue
        if as_tensor_variable(dim).type.ndim != 0:
            raise NotImplementedError(
                "It is not possible to vectorize the shape argument of SpecifyShape"
            )

    n_given = len(shape)
    if n_given == len(old_shape):
        new_shape = (None,) * batched_ndims + tuple(shape)
    elif n_given == len(old_shape) + batched_ndims:
        new_shape = shape
    else:
        raise ValueError(
            "Invalid number of shape arguments passed into vectorize node of SpecifyShape"
        )

    return specify_shape(x, new_shape).owner
class Reshape(COp):
"""Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be
......@@ -638,7 +668,7 @@ class Reshape(COp):
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
def perform(self, node, inp, out_, params):
def perform(self, node, inp, out_, params=None):
x, shp = inp
(out,) = out_
if len(shp) != self.ndim:
......@@ -770,6 +800,26 @@ class Reshape(COp):
"""
@_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape):
    """Vectorize `Reshape` for an input `x` with extra batch dims.

    A same-length target shape gets the batch dims prepended from `x`'s own
    symbolic shape; a target shape that already covers the batch dims is
    used unchanged.
    """
    old_x, old_shape = node.inputs
    batched_ndims = x.type.ndim - old_x.type.ndim

    # Only a 1d (vector) shape argument is supported.
    if as_tensor_variable(shape).type.ndim != 1:
        raise NotImplementedError(
            "It is not possible to vectorize the shape argument of Reshape"
        )

    old_len = len(tuple(old_shape))
    new_len = len(tuple(shape))
    if old_len == new_len:
        new_shape = [*x.shape[:batched_ndims], *shape]
    elif old_len == new_len - batched_ndims:
        new_shape = shape
    else:
        raise ValueError("Invalid shape length passed into vectorize node of Reshape")

    return reshape(x, new_shape, ndim=len(new_shape)).owner
def reshape(x, newshape, ndim=None):
if ndim is None:
newshape = at.as_tensor_variable(newshape)
......@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes):
if not unbroadcasted_axes:
return x
return Unbroadcast(*unbroadcasted_axes)(x)
@_vectorize_node.register(Unbroadcast)
def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
    """Vectorize `Unbroadcast` by shifting its axes past the new batch dims."""
    offset = x.type.ndim - node.inputs[0].type.ndim
    shifted_axes = [axis + offset for axis in op.axes]
    return unbroadcast(x, *shifted_axes).owner
......@@ -5,14 +5,14 @@ import pytensor
from pytensor import Mode, function, grad
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, as_tensor, constant
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
......@@ -706,3 +706,88 @@ def test_shape_tuple():
assert isinstance(res[1], ScalarConstant)
assert res[1].data == 2
assert not isinstance(res[2], ScalarConstant)
class TestVectorize:
    """Tests for `vectorize_node` dispatch on the shape-related Ops:
    Shape, Reshape, SpecifyShape, and Unbroadcast."""
    def test_shape(self):
        # Vectorizing Shape is trivial: the batched node is just Shape(mat).
        vec = tensor(shape=(None,))
        mat = tensor(shape=(None, None))
        node = shape(vec).owner
        vect_node = vectorize_node(node, mat)
        assert equal_computations(vect_node.outputs, [shape(mat)])
    def test_reshape(self):
        x = scalar("x", dtype=int)
        vec = tensor(shape=(None,))
        mat = tensor(shape=(None, None))
        # NOTE: this local tuple shadows the `shape` function used above.
        shape = (2, x)
        node = reshape(vec, shape).owner
        # Same-length shape: batch dims are prepended from the input's own shape.
        vect_node = vectorize_node(node, mat, shape)
        assert equal_computations(
            vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))]
        )
        # Shape that already includes the batch dim is used as-is.
        new_shape = (5, 2, x)
        vect_node = vectorize_node(node, mat, new_shape)
        assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)])
        # A 2d shape argument cannot be vectorized.
        with pytest.raises(NotImplementedError):
            vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))
        # Shape lengths incompatible with the batched input's ndim raise.
        with pytest.raises(
            ValueError,
            match="Invalid shape length passed into vectorize node of Reshape",
        ):
            vectorize_node(node, vec, (5, 2, x))
        with pytest.raises(
            ValueError,
            match="Invalid shape length passed into vectorize node of Reshape",
        ):
            vectorize_node(node, mat, (5, 3, 2, x))
    def test_specify_shape(self):
        x = scalar("x", dtype=int)
        mat = tensor(shape=(None, None))
        tns = tensor(shape=(None, None, None))
        shape = (x, None)
        node = specify_shape(mat, shape).owner
        # Same number of entries: new batch dims stay unconstrained (None).
        vect_node = vectorize_node(node, tns, *shape)
        assert equal_computations(
            vect_node.outputs, [specify_shape(tns, (None, x, None))]
        )
        # Entries that already cover the batch dims are used as-is.
        new_shape = (5, 2, x)
        vect_node = vectorize_node(node, tns, *new_shape)
        assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))])
        # Non-scalar shape entries cannot be vectorized.
        with pytest.raises(NotImplementedError):
            vectorize_node(node, mat, *([x, x], None))
        # Wrong number of entries for the batched input's ndim raises.
        with pytest.raises(
            ValueError,
            match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
        ):
            vectorize_node(node, mat, *(5, 2, x))
        with pytest.raises(
            ValueError,
            match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
        ):
            vectorize_node(node, tns, *(5, 3, 2, x))
    def test_unbroadcast(self):
        mat = tensor(
            shape=(
                1,
                1,
            )
        )
        tns = tensor(shape=(4, 1, 1, 1))
        # The unbroadcast axis is shifted right by the 2 new batch dims: 0 -> 2.
        node = unbroadcast(mat, 0).owner
        vect_node = vectorize_node(node, tns)
        assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论