提交 65967fe2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement vectorize_node for Softmax and Argmax Ops

Also refactors the shared batch-axis normalization logic used by other vectorized Ops
上级 08a9ba35
...@@ -43,7 +43,12 @@ from pytensor.tensor import ( ...@@ -43,7 +43,12 @@ from pytensor.tensor import (
get_vector_length, get_vector_length,
) )
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.elemwise import (
DimShuffle,
Elemwise,
get_normalized_batch_axes,
scalar_elemwise,
)
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
...@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1): ...@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
@_vectorize_node.register(ExtractDiag)
def vectorize_extract_diag(op: ExtractDiag, node, batch_x):
    """Vectorize an ExtractDiag node by shifting axis1/axis2 past the batch dims.

    The batched input gains `batch_ndim` leading dimensions, so the two core
    axes the diagonal is taken from must be shifted right by that amount
    (after normalizing possibly-negative axis values).
    """
    core_ndim = node.inputs[0].type.ndim
    batch_ndim = batch_x.type.ndim - core_ndim
    batch_axis1, batch_axis2 = get_normalized_batch_axes(
        (op.axis1, op.axis2), core_ndim, batch_ndim
    )
    return diagonal(
        batch_x,
        offset=op.offset,
        axis1=batch_axis1,
        axis2=batch_axis2,
    ).owner
......
from copy import copy from copy import copy
from typing import Union
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple
import pytensor.tensor.basic import pytensor.tensor.basic
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -1399,7 +1401,7 @@ class CAReduce(COp): ...@@ -1399,7 +1401,7 @@ class CAReduce(COp):
# scalar inputs are treated as 1D regarding axis in this `Op` # scalar inputs are treated as 1D regarding axis in this `Op`
if axis is not None: if axis is not None:
try: try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims)) axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
except np.AxisError: except np.AxisError:
raise np.AxisError(axis, ndim=inp_dims) raise np.AxisError(axis, ndim=inp_dims)
...@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl ...@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
return DimShuffle(input_broadcastable, new_order).make_node(x) return DimShuffle(input_broadcastable, new_order).make_node(x)
def get_normalized_batch_axes(
    core_axes: Union[None, int, tuple[int, ...]],
    core_ndim: int,
    batch_ndim: int,
) -> tuple[int, ...]:
    """Compute batch axes for a batched operation, from the core input ndim and axes.

    Parameters
    ----------
    core_axes
        Axes of the core (unbatched) operation; ``None`` means all core axes.
        Negative values are allowed and normalized against ``core_ndim``.
    core_ndim
        Number of dimensions of the core input.
    batch_ndim
        Number of *leading* batch dimensions added to the core input.

    Returns
    -------
    tuple[int, ...]
        The core axes shifted right by ``batch_ndim``.

    Examples
    --------
    sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)):
        get_normalized_batch_axes(None, 2, 2) -> (2, 3)

    sum(matrix, axis=0) -> sum(tensor4, axis=(2,)):
        get_normalized_batch_axes(0, 2, 2) -> (2,)

    sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3)):
        get_normalized_batch_axes((0, -1), 3, 1) -> (1, 3)
    """
    if core_axes is None:
        # None means "reduce over every core axis"
        core_axes = tuple(range(core_ndim))
    else:
        # Accepts a single int or a tuple; normalizes negative axes
        core_axes = normalize_axis_tuple(core_axes, core_ndim)
    return tuple(core_axis + batch_ndim for core_axis in core_axes)
@_vectorize_node.register(CAReduce)
def vectorize_careduce(op: CAReduce, node: Apply, batch_x: TensorVariable) -> Apply:
    """Vectorize a CAReduce node by shifting its reduction axes past the batch dims."""
    core_ndim = node.inputs[0].type.ndim
    batch_ndim = batch_x.type.ndim - core_ndim
    if not batch_ndim:
        # No new leading dimensions: the original Op applies unchanged.
        return node.op.make_node(batch_x)
    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
    return op.clone(axis=batch_axes).make_node(batch_x)
...@@ -27,7 +27,13 @@ from pytensor.tensor.basic import ( ...@@ -27,7 +27,13 @@ from pytensor.tensor.basic import (
switch, switch,
) )
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.elemwise import (
CAReduce,
DimShuffle,
Elemwise,
get_normalized_batch_axes,
scalar_elemwise,
)
from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
DenseTensorType, DenseTensorType,
...@@ -134,7 +140,7 @@ class MaxAndArgmax(COp): ...@@ -134,7 +140,7 @@ class MaxAndArgmax(COp):
_f16_ok = True _f16_ok = True
def __init__(self, axis):
    # Accept either a tuple or a list of axes; store a normalized tuple so
    # the Op is hashable and comparisons between equivalent Ops succeed.
    assert isinstance(axis, (tuple, list))
    self.axis = tuple(axis)
def get_params(self, node): def get_params(self, node):
...@@ -465,6 +471,19 @@ class Argmax(COp): ...@@ -465,6 +471,19 @@ class Argmax(COp):
return [x.zeros_like()] return [x.zeros_like()]
@_vectorize_node.register(Argmax)
@_vectorize_node.register(MaxAndArgmax)
def vectorize_argmax_node(op, node, batch_x):
    """Vectorize Argmax/MaxAndArgmax by shifting reduced axes past the batch dims.

    Both Ops accept a tuple of axes, so no Blockwise fallback is needed:
    a new Op of the same type with the shifted axes is built directly.
    """
    core_ndim = node.inputs[0].type.ndim
    batch_ndim = batch_x.type.ndim - core_ndim
    if not batch_ndim:
        # No new leading dimensions: the original Op applies unchanged.
        return node.op.make_node(batch_x)
    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
    return type(op)(axis=batch_axes).make_node(batch_x)
def makeKeepDims(x, y, axis): def makeKeepDims(x, y, axis):
""" """
Reintroduces in y with length one the axes of x which have been left out Reintroduces in y with length one the axes of x which have been left out
......
...@@ -18,6 +18,7 @@ from pytensor.scalar import int32 ...@@ -18,6 +18,7 @@ from pytensor.scalar import int32
from pytensor.tensor import _get_vector_length, as_tensor_variable from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
...@@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes): ...@@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes):
@_vectorize_node.register(Unbroadcast)
def _vectorize_unbroadcast(
    op: Unbroadcast, node: Apply, batch_x: TensorVariable
) -> Apply:
    """Vectorize an Unbroadcast node by shifting its axes past the batch dims."""
    core_ndim = node.inputs[0].type.ndim
    batch_ndim = batch_x.type.ndim - core_ndim
    batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim)
    return cast(Apply, unbroadcast(batch_x, *batch_axes).owner)
...@@ -4,8 +4,10 @@ import numpy as np ...@@ -4,8 +4,10 @@ import numpy as np
import scipy import scipy
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.math import gamma, gammaln, neg, sum from pytensor.tensor.math import gamma, gammaln, neg, sum
...@@ -736,6 +738,32 @@ def log_softmax(c, axis=None): ...@@ -736,6 +738,32 @@ def log_softmax(c, axis=None):
return LogSoftmax(axis=axis)(c) return LogSoftmax(axis=axis)(c)
@_vectorize_node.register(Softmax)
@_vectorize_node.register(LogSoftmax)
def vectorize_softmax_node(op, node, batched_x):
    """Vectorize Softmax and LogSoftmax nodes.

    If the (normalized) batch axes reduce to a single axis, an Op of the same
    type with the shifted axis is used; otherwise the generic Blockwise
    fallback is applied.
    """
    core_ndim = node.inputs[0].type.ndim
    batch_ndim = batched_x.type.ndim - core_ndim
    if not batch_ndim:
        # No new leading dimensions: the original Op applies unchanged.
        return op.make_node(batched_x)

    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)

    if len(batch_axes) > 1:
        from pytensor.tensor.blockwise import vectorize_node_fallback

        # The softmax Ops only allow a specific axis (integer) or all axis (None).
        # If the vectorized operation requires more than one axis we have to default to a Blockwise
        return vectorize_node_fallback(op, node, batched_x)

    [batch_axis] = batch_axes
    return type(op)(axis=batch_axis).make_node(batched_x)
def poch(z, m): def poch(z, m):
""" """
Pochhammer symbol (rising factorial) function. Pochhammer symbol (rising factorial) function.
......
...@@ -20,6 +20,7 @@ from pytensor.configdefaults import config ...@@ -20,6 +20,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, grad, numeric_grad from pytensor.gradient import NullTypeGradError, grad, numeric_grad
from pytensor.graph.basic import Variable, applys_between from pytensor.graph.basic import Variable, applys_between
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.link.c.basic import DualLinker from pytensor.link.c.basic import DualLinker
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import pprint from pytensor.printing import pprint
...@@ -1010,6 +1011,35 @@ class TestMaxAndArgmax: ...@@ -1010,6 +1011,35 @@ class TestMaxAndArgmax:
assert max_pt.eval() == 3 assert max_pt.eval() == 3
assert argmax_pt.eval() == 2 assert argmax_pt.eval() == 2
@pytest.mark.parametrize(
    "core_axis, batch_axis",
    [
        (None, (1, 2, 3, 4)),
        (0, (1,)),
        ((1, -1), (2, 4)),
    ],
)
def test_vectorize(self, core_axis, batch_axis):
    """MaxAndArgmax/Argmax vectorization shifts axes by the batch ndim."""
    x = tensor(shape=(5, 5, 5, 5))
    batch_x = tensor(shape=(3, 5, 5, 5, 5))

    # Test MaxAndArgmax
    max_x, argmax_x = max_and_argmax(x, axis=core_axis)
    node = max_x.owner
    assert isinstance(node.op, MaxAndArgmax)

    new_node = vectorize_node(node, batch_x)
    assert isinstance(new_node.op, MaxAndArgmax)
    assert new_node.op.axis == batch_axis

    # Test Argmax
    # Argmax is not user-facing, so we have to create it manually
    node = Argmax(axis=node.op.axis).make_node(x)

    new_node = vectorize_node(node, batch_x)
    assert isinstance(new_node.op, Argmax)
    assert new_node.op.axis == batch_axis
class TestArgminArgmax: class TestArgminArgmax:
def setup_method(self): def setup_method(self):
......
...@@ -8,6 +8,8 @@ from scipy.special import softmax as scipy_softmax ...@@ -8,6 +8,8 @@ from scipy.special import softmax as scipy_softmax
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.replace import vectorize_node
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.special import ( from pytensor.tensor.special import (
LogSoftmax, LogSoftmax,
Softmax, Softmax,
...@@ -19,7 +21,7 @@ from pytensor.tensor.special import ( ...@@ -19,7 +21,7 @@ from pytensor.tensor.special import (
poch, poch,
softmax, softmax,
) )
from pytensor.tensor.type import matrix, tensor3, tensor4, vector, vectors from pytensor.tensor.type import matrix, tensor, tensor3, tensor4, vector, vectors
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import random_ranged from tests.tensor.utils import random_ranged
...@@ -150,6 +152,34 @@ class TestSoftmaxGrad(utt.InferShapeTester): ...@@ -150,6 +152,34 @@ class TestSoftmaxGrad(utt.InferShapeTester):
SoftmaxGrad(-4)(*x) SoftmaxGrad(-4)(*x)
@pytest.mark.parametrize(
    "core_axis, batch_axis",
    [
        (None, (1, 2, 3, 4)),
        (0, (1,)),
    ],
)
@pytest.mark.parametrize(
    "op, constructor", [(Softmax, softmax), (LogSoftmax, log_softmax)]
)
def test_vectorize_softmax(op, constructor, core_axis, batch_axis):
    """Single batched axis keeps the same Op type; multiple fall back to Blockwise."""
    x = tensor(shape=(5, 5, 5, 5))
    batch_x = tensor(shape=(3, 5, 5, 5, 5))

    node = constructor(x, axis=core_axis).owner
    assert isinstance(node.op, op)

    new_node = vectorize_node(node, batch_x)
    if len(batch_axis) == 1:
        # One target axis: representable by the same (single-axis) Op.
        assert isinstance(new_node.op, op)
        assert (new_node.op.axis,) == batch_axis
    else:
        # Multiple target axes are not supported by the Op: Blockwise fallback.
        assert isinstance(new_node.op, Blockwise) and isinstance(
            new_node.op.core_op, op
        )
        assert new_node.op.core_op.axis == core_axis
def test_poch(): def test_poch():
_z, _m = vectors("z", "m") _z, _m = vectors("z", "m")
actual_fn = function([_z, _m], poch(_z, _m)) actual_fn = function([_z, _m], poch(_z, _m))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论