Commit 7eafb6c6, authored by Ricardo Vieira, committed by Ricardo Vieira

Vectorize shape operations

上级 7f0567a8
...@@ -204,7 +204,7 @@ def graph_replace( ...@@ -204,7 +204,7 @@ def graph_replace(
@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
    """Dispatch entry point: build a vectorized version of `node` for `op`.

    Registered implementations receive the original `node` plus the batched
    replacements for its inputs and must return a new `Apply` node.
    """
    # Default implementation is provided in pytensor.tensor.blockwise
    raise NotImplementedError
...@@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply: ...@@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
return _vectorize_node(op, node, *batched_inputs) return _vectorize_node(op, node, *batched_inputs)
def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs)
@overload @overload
def vectorize_graph( def vectorize_graph(
outputs: Variable, outputs: Variable,
......
...@@ -8,7 +8,11 @@ from pytensor.gradient import DisconnectedType ...@@ -8,7 +8,11 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize_graph from pytensor.graph.replace import (
_vectorize_node,
_vectorize_not_needed,
vectorize_graph,
)
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
...@@ -37,17 +41,6 @@ def safe_signature( ...@@ -37,17 +41,6 @@ def safe_signature(
return f"{inputs_sig}->{outputs_sig}" return f"{inputs_sig}->{outputs_sig}"
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature
else:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature = safe_signature(node.inputs, node.outputs)
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
class Blockwise(Op): class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions. """Generalizes a core `Op` to work with batched dimensions.
...@@ -361,6 +354,15 @@ class Blockwise(Op): ...@@ -361,6 +354,15 @@ class Blockwise(Op):
return self.name return self.name
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *batched_inputs) -> Apply:
    """Generic fallback: wrap `op` in a `Blockwise` to handle batched inputs.

    Uses the op's `gufunc_signature` when it declares one; otherwise derives
    a conservative signature from the node's inputs and outputs.
    """
    if hasattr(op, "gufunc_signature"):
        signature = op.gufunc_signature
    else:
        # TODO: This is pretty bad for shape inference and merge optimization!
        # Should get better as we add signatures to our Ops
        signature = safe_signature(node.inputs, node.outputs)
    return cast(Apply, Blockwise(op, signature=signature).make_node(*batched_inputs))


# A Blockwise is already the vectorized form, so re-applying it suffices.
_vectorize_node.register(Blockwise, _vectorize_not_needed)
...@@ -7,7 +7,7 @@ from pytensor.configdefaults import config ...@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
...@@ -22,7 +22,6 @@ from pytensor.scalar.basic import transfer_type, upcast ...@@ -22,7 +22,6 @@ from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.blockwise import vectorize_not_needed
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
continuous_dtypes, continuous_dtypes,
...@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var): ...@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var):
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
_vectorize_node.register(Elemwise, vectorize_not_needed) _vectorize_node.register(Elemwise, _vectorize_not_needed)
@_vectorize_node.register(DimShuffle) @_vectorize_node.register(DimShuffle)
......
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import pytensor import pytensor
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
...@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var): ...@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim return var.owner.inputs[0].type.ndim
_vectorize_node.register(Shape, _vectorize_not_needed)
def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]: def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
r"""Get a tuple of symbolic shape values. r"""Get a tuple of symbolic shape values.
...@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var): ...@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var):
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
@_vectorize_node.register(SpecifyShape)
def _vectorize_specify_shape(op, node, x, *shape):
    """Vectorize a SpecifyShape node over the leading batch dimensions of `x`."""
    original_x, *original_shape = node.inputs
    batch_ndims = x.type.ndim - original_x.type.ndim

    # Only scalar (0d) shape entries can be handled; anything batched is rejected.
    for dim in shape:
        if NoneConst.equals(dim) or dim is None:
            continue
        if as_tensor_variable(dim).type.ndim != 0:
            raise NotImplementedError(
                "It is not possible to vectorize the shape argument of SpecifyShape"
            )

    if len(shape) == len(original_shape):
        # Leave the new batch dimensions unconstrained (None).
        new_shape = (None,) * batch_ndims + shape
    elif len(shape) == len(original_shape) + batch_ndims:
        new_shape = shape
    else:
        raise ValueError(
            "Invalid number of shape arguments passed into vectorize node of SpecifyShape"
        )
    return specify_shape(x, new_shape).owner
class Reshape(COp): class Reshape(COp):
"""Perform a reshape operation of the input x to the new shape shp. """Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be The number of dimensions to which to reshape to (ndim) must be
...@@ -638,7 +668,7 @@ class Reshape(COp): ...@@ -638,7 +668,7 @@ class Reshape(COp):
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)]) return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
def perform(self, node, inp, out_, params): def perform(self, node, inp, out_, params=None):
x, shp = inp x, shp = inp
(out,) = out_ (out,) = out_
if len(shp) != self.ndim: if len(shp) != self.ndim:
...@@ -770,6 +800,26 @@ class Reshape(COp): ...@@ -770,6 +800,26 @@ class Reshape(COp):
""" """
@_vectorize_node.register(Reshape)
def _vectorize_reshape(op, node, x, shape):
    """Vectorize a Reshape node over the leading batch dimensions of `x`."""
    original_x, original_shape = node.inputs
    batch_ndims = x.type.ndim - original_x.type.ndim

    if as_tensor_variable(shape).type.ndim != 1:
        raise NotImplementedError(
            "It is not possible to vectorize the shape argument of Reshape"
        )

    old_len = len(tuple(original_shape))
    new_len = len(tuple(shape))
    if old_len == new_len:
        # Prepend the batch dimensions of `x` to the requested target shape.
        new_shape = [*x.shape[:batch_ndims], *shape]
    elif new_len == old_len + batch_ndims:
        new_shape = shape
    else:
        raise ValueError("Invalid shape length passed into vectorize node of Reshape")
    return reshape(x, new_shape, ndim=len(new_shape)).owner
def reshape(x, newshape, ndim=None): def reshape(x, newshape, ndim=None):
if ndim is None: if ndim is None:
newshape = at.as_tensor_variable(newshape) newshape = at.as_tensor_variable(newshape)
...@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes): ...@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes):
if not unbroadcasted_axes: if not unbroadcasted_axes:
return x return x
return Unbroadcast(*unbroadcasted_axes)(x) return Unbroadcast(*unbroadcasted_axes)(x)
@_vectorize_node.register(Unbroadcast)
def _vectorize_unbroadcast(op: Unbroadcast, node: Apply, x: TensorVariable) -> Apply:
    """Shift the Unbroadcast axes right by the number of new batch dimensions."""
    batch_ndims = x.type.ndim - node.inputs[0].type.ndim
    shifted_axes = [axis + batch_ndims for axis in op.axes]
    return unbroadcast(x, *shifted_axes).owner
...@@ -5,14 +5,14 @@ import pytensor ...@@ -5,14 +5,14 @@ import pytensor
from pytensor import Mode, function, grad from pytensor import Mode, function, grad
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant from pytensor.tensor.basic import MakeVector, as_tensor, constant
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
...@@ -706,3 +706,88 @@ def test_shape_tuple(): ...@@ -706,3 +706,88 @@ def test_shape_tuple():
assert isinstance(res[1], ScalarConstant) assert isinstance(res[1], ScalarConstant)
assert res[1].data == 2 assert res[1].data == 2
assert not isinstance(res[2], ScalarConstant) assert not isinstance(res[2], ScalarConstant)
class TestVectorize:
    """Tests for `vectorize_node` dispatch on shape operations."""

    def test_shape(self):
        vec = tensor(shape=(None,))
        mat = tensor(shape=(None, None))

        node = shape(vec).owner
        vect_node = vectorize_node(node, mat)
        assert equal_computations(vect_node.outputs, [shape(mat)])

    def test_reshape(self):
        x = scalar("x", dtype=int)
        vec = tensor(shape=(None,))
        mat = tensor(shape=(None, None))

        shape = (2, x)
        node = reshape(vec, shape).owner

        # Same-length shape: batch dimensions of `mat` are prepended.
        vect_node = vectorize_node(node, mat, shape)
        assert equal_computations(
            vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))]
        )

        # Shape already includes the batch dimension: used as-is.
        new_shape = (5, 2, x)
        vect_node = vectorize_node(node, mat, new_shape)
        assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)])

        # A non-vector (batched) shape argument cannot be vectorized.
        with pytest.raises(NotImplementedError):
            vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))

        with pytest.raises(
            ValueError,
            match="Invalid shape length passed into vectorize node of Reshape",
        ):
            vectorize_node(node, vec, (5, 2, x))

        with pytest.raises(
            ValueError,
            match="Invalid shape length passed into vectorize node of Reshape",
        ):
            vectorize_node(node, mat, (5, 3, 2, x))

    def test_specify_shape(self):
        x = scalar("x", dtype=int)
        mat = tensor(shape=(None, None))
        tns = tensor(shape=(None, None, None))

        shape = (x, None)
        node = specify_shape(mat, shape).owner

        # Same number of shape entries: new batch dims are left unconstrained.
        vect_node = vectorize_node(node, tns, *shape)
        assert equal_computations(
            vect_node.outputs, [specify_shape(tns, (None, x, None))]
        )

        # Shape entries already cover the batch dimension: used as-is.
        new_shape = (5, 2, x)
        vect_node = vectorize_node(node, tns, *new_shape)
        assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))])

        # Non-scalar shape entries cannot be vectorized.
        with pytest.raises(NotImplementedError):
            vectorize_node(node, mat, *([x, x], None))

        with pytest.raises(
            ValueError,
            match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
        ):
            vectorize_node(node, mat, *(5, 2, x))

        with pytest.raises(
            ValueError,
            match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
        ):
            vectorize_node(node, tns, *(5, 3, 2, x))

    def test_unbroadcast(self):
        mat = tensor(
            shape=(
                1,
                1,
            )
        )
        tns = tensor(shape=(4, 1, 1, 1))

        # Axis 0 of `mat` maps to axis 2 of `tns` after 2 batch dims are added.
        node = unbroadcast(mat, 0).owner
        vect_node = vectorize_node(node, tns)
        assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论