提交 49f76da9 authored 作者: copilot-swe-agent[bot]'s avatar copilot-swe-agent[bot] 提交者: Ricardo Vieira

Implement axis=None raveling behavior symbolically in CumOp

上级 9b522a86
...@@ -41,21 +41,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -41,21 +41,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
mode = op.mode mode = op.mode
ndim = cast(TensorVariable, node.outputs[0]).ndim ndim = cast(TensorVariable, node.outputs[0]).ndim
if axis is not None:
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
reaxis_first = (axis, *(i for i in range(ndim) if i != axis)) reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
reaxis_first_inv = tuple(np.argsort(reaxis_first)) reaxis_first_inv = tuple(np.argsort(reaxis_first))
if mode == "add": if mode == "add":
if axis is None or ndim == 1: if ndim == 1:
@numba_basic.numba_njit @numba_basic.numba_njit
def cumop(x): def cumop(x):
return np.cumsum(x) return np.cumsum(x, axis=axis)
else: else:
...@@ -75,11 +69,11 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -75,11 +69,11 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
return res.transpose(reaxis_first_inv) return res.transpose(reaxis_first_inv)
else: else:
if axis is None or ndim == 1: if ndim == 1:
@numba_basic.numba_njit @numba_basic.numba_njit
def cumop(x): def cumop(x):
return np.cumprod(x) return np.cumprod(x, axis=axis)
else: else:
...@@ -96,7 +90,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): ...@@ -96,7 +90,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
for m in range(1, x.shape[axis]): for m in range(1, x.shape[axis]):
res[m] = res[m - 1] * x_axis_first[m] res[m] = res[m - 1] * x_axis_first[m]
return res.transpose(reaxis_first) return res.transpose(reaxis_first_inv)
return cumop return cumop
......
...@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs): ...@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
mode = op.mode mode = op.mode
def cumop(x): def cumop(x):
if axis is None:
x = x.reshape(-1)
dim = 0
else:
dim = axis
if mode == "add": if mode == "add":
return torch.cumsum(x, dim=dim) return torch.cumsum(x, dim=axis)
else: else:
return torch.cumprod(x, dim=dim) return torch.cumprod(x, dim=axis)
return cumop return cumop
......
import warnings import warnings
from collections.abc import Collection, Iterable from collections.abc import Collection, Iterable
from textwrap import dedent
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_index from numpy.lib.array_utils import normalize_axis_index
...@@ -44,10 +45,10 @@ from pytensor.tensor.math import max as pt_max ...@@ -44,10 +45,10 @@ from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.utils import normalize_reduce_axis
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import LOCAL_BITWIDTH, NPY_RAVEL_AXIS, PYTHON_INT_BITWIDTH from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
class CpuContiguous(COp): class CpuContiguous(COp):
...@@ -290,33 +291,28 @@ class CumOp(COp): ...@@ -290,33 +291,28 @@ class CumOp(COp):
__props__ = ("axis", "mode") __props__ = ("axis", "mode")
check_input = False check_input = False
params_type = ParamsType( params_type = ParamsType(
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul")) axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
) )
def __init__(self, axis: int | None = None, mode="add"): def __init__(self, axis: int, mode="add"):
if mode not in ("add", "mul"): if mode not in ("add", "mul"):
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"') raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
if not (isinstance(axis, int) or axis is None): if not isinstance(axis, int):
raise TypeError("axis must be an integer or None.") raise TypeError(f"axis must be an integer, got {axis} of type {type(axis)}")
if axis < 0:
raise ValueError(f"axis must be non-negative, got {axis}")
self.axis = axis self.axis = axis
self.mode = mode self.mode = mode
@property
def c_axis(self) -> int:
if self.axis is None:
return NPY_RAVEL_AXIS
return self.axis
def make_node(self, x): def make_node(self, x):
x = ptb.as_tensor_variable(x) x = ptb.as_tensor_variable(x)
out_type = x.type()
if self.axis is None: if self.axis >= x.type.ndim:
out_type = vector(dtype=x.dtype) # Flatten raise ValueError(
elif self.axis >= x.ndim or self.axis < -x.ndim: f"axis(={self.axis}) out of bounds for variable {x} with {x.type.ndim} ndims"
raise ValueError(f"axis(={self.axis}) out of bounds") )
return Apply(self, [x], [out_type]) return Apply(self, [x], [x.type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x = inputs[0] x = inputs[0]
...@@ -326,21 +322,10 @@ class CumOp(COp): ...@@ -326,21 +322,10 @@ class CumOp(COp):
else: else:
z[0] = np.cumprod(x, axis=self.axis) z[0] = np.cumprod(x, axis=self.axis)
def grad(self, inputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
(x,) = inputs (x,) = inputs
(gi,) = output_gradients (gi,) = output_gradients
if self.axis is None:
if self.mode == "add":
return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
elif self.mode == "mul":
fx = cumprod(x, axis=self.axis)
return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
else:
raise NotImplementedError(
f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
)
reverse_slicing = [slice(None, None, None)] * gi.ndim reverse_slicing = [slice(None, None, None)] * gi.ndim
reverse_slicing[self.axis] = slice(None, None, -1) reverse_slicing[self.axis] = slice(None, None, -1)
reverse_slicing = tuple(reverse_slicing) reverse_slicing = tuple(reverse_slicing)
...@@ -357,9 +342,6 @@ class CumOp(COp): ...@@ -357,9 +342,6 @@ class CumOp(COp):
) )
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
if self.axis is None and len(shapes[0]) > 1:
return [(prod(shapes[0]),)] # Flatten
return shapes return shapes
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
...@@ -368,61 +350,43 @@ class CumOp(COp): ...@@ -368,61 +350,43 @@ class CumOp(COp):
fail = sub["fail"] fail = sub["fail"]
params = sub["params"] params = sub["params"]
if self.axis is None: return dedent(
axis_code = "int axis = NPY_RAVEL_AXIS;\n" f"""
else: int axis = {params}->axis;
axis_code = f"int axis = {params}->c_axis;\n"
code = (
axis_code
+ f"""
#undef NPY_UF_DBG_TRACING
#define NPY_UF_DBG_TRACING 1
if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_RAVEL_AXIS;
npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
{{
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
}}
else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
{{ {{
Py_XDECREF({z}); Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x})); {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
if (!{z}){{ {fail} }};
}} }}
if (!{z})
{fail};
{{ {{
PyObject * t = NULL; PyObject * t = NULL;
if({params}->mode == MODE_ADD) if({params}->mode == MODE_ADD)
t = PyArray_CumSum( t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
{x}, axis,
PyArray_TYPE({x}), {z});
else if({params}->mode == MODE_MUL) else if({params}->mode == MODE_MUL)
t = PyArray_CumProd( t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
{x}, axis,
PyArray_TYPE({x}), {z});
if (!t){{ if (!t){{
{fail}; {fail};
}} }}
// Because PyArray_CumSum/CumProd returns a newly created reference on t. // Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t); Py_XDECREF(t);
}} }}
""" """
) )
return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (10,) return (11,)
def __str__(self): def __str__(self):
if self.mode == "add":
return f"Cumsum{{axis={self.axis}}}"
elif self.mode == "mul":
return f"Cumprod{{axis={self.axis}}}"
return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}" return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
...@@ -443,6 +407,12 @@ def cumsum(x, axis=None): ...@@ -443,6 +407,12 @@ def cumsum(x, axis=None):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
x = ptb.as_tensor_variable(x)
if axis is None:
x = x.ravel()
axis = 0
else:
axis = normalize_axis_index(axis, x.ndim)
return CumOp(axis=axis, mode="add")(x) return CumOp(axis=axis, mode="add")(x)
...@@ -463,6 +433,12 @@ def cumprod(x, axis=None): ...@@ -463,6 +433,12 @@ def cumprod(x, axis=None):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
x = ptb.as_tensor_variable(x)
if axis is None:
x = x.ravel()
axis = 0
else:
axis = normalize_axis_index(axis, x.ndim)
return CumOp(axis=axis, mode="mul")(x) return CumOp(axis=axis, mode="mul")(x)
...@@ -471,18 +447,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x): ...@@ -471,18 +447,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
"""Vectorize the CumOp to work on a batch of inputs.""" """Vectorize the CumOp to work on a batch of inputs."""
[original_x] = node.inputs [original_x] = node.inputs
batch_ndim = batch_x.ndim - original_x.ndim batch_ndim = batch_x.ndim - original_x.ndim
axis = op.axis # op.axis is already normalized and non-negative
if axis is None and original_x.ndim == 1: return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
axis = 0
elif axis is not None:
axis = normalize_axis_index(op.axis, original_x.ndim)
if axis is None:
# Ravel all unbatched dimensions and perform CumOp on the last axis
batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
else:
return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
def diff(x, n=1, axis=-1): def diff(x, n=1, axis=-1):
......
...@@ -38,11 +38,6 @@ def test_Bartlett(val): ...@@ -38,11 +38,6 @@ def test_Bartlett(val):
1, 1,
"add", "add",
), ),
(
(pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))),
-1,
"add",
),
( (
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
0, 0,
...@@ -53,11 +48,6 @@ def test_Bartlett(val): ...@@ -53,11 +48,6 @@ def test_Bartlett(val):
1, 1,
"add", "add",
), ),
(
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
None,
"add",
),
( (
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))), (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
0, 0,
...@@ -68,11 +58,6 @@ def test_Bartlett(val): ...@@ -68,11 +58,6 @@ def test_Bartlett(val):
1, 1,
"mul", "mul",
), ),
(
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
None,
"mul",
),
], ],
) )
def test_CumOp(val, axis, mode): def test_CumOp(val, axis, mode):
......
...@@ -5,39 +5,13 @@ import pytensor.tensor as pt ...@@ -5,39 +5,13 @@ import pytensor.tensor as pt
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize( @pytest.mark.parametrize("dtype", ["float64", "int64"])
"dtype", @pytest.mark.parametrize("axis", [None, -1])
["float64", "int64"],
)
@pytest.mark.parametrize(
"axis",
[None, 1, (0,)],
)
def test_pytorch_CumOp(axis, dtype): def test_pytorch_CumOp(axis, dtype):
"""Test PyTorch conversion of the `CumOp` `Op`."""
# Create a symbolic input for the first input of `CumOp`
a = pt.matrix("a", dtype=dtype) a = pt.matrix("a", dtype=dtype)
# Create test value
test_value = np.arange(9, dtype=dtype).reshape((3, 3)) test_value = np.arange(9, dtype=dtype).reshape((3, 3))
outs = [pt.cumsum(a, axis=axis), pt.cumprod(a, axis=axis)]
# Create the output variable compare_pytorch_and_py([a], outs, [test_value])
if isinstance(axis, tuple):
with pytest.raises(TypeError, match="axis must be an integer or None\\."):
out = pt.cumsum(a, axis=axis)
with pytest.raises(TypeError, match="axis must be an integer or None\\."):
out = pt.cumprod(a, axis=axis)
else:
out = pt.cumsum(a, axis=axis)
# Pass the inputs and outputs to the testing function
compare_pytorch_and_py([a], [out], [test_value])
# For the second mode of CumOp
out = pt.cumprod(a, axis=axis)
compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)]) @pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
......
...@@ -195,7 +195,7 @@ class TestCumOp(utt.InferShapeTester): ...@@ -195,7 +195,7 @@ class TestCumOp(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.op_class = CumOp self.op_class = CumOp
self.op = CumOp() self.op = CumOp(axis=0)
def test_cum_op(self): def test_cum_op(self):
x = tensor3("x") x = tensor3("x")
...@@ -226,8 +226,8 @@ class TestCumOp(utt.InferShapeTester): ...@@ -226,8 +226,8 @@ class TestCumOp(utt.InferShapeTester):
x = tensor3("x") x = tensor3("x")
a = np.random.random((3, 5, 2)).astype(config.floatX) a = np.random.random((3, 5, 2)).astype(config.floatX)
# Test axis=None # Test default axis=None
self._compile_and_check([x], [self.op(x)], [a], self.op_class) self._compile_and_check([x], [cumsum(x)], [a], self.op_class)
for axis in range(-len(a.shape), len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class) self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)
...@@ -235,10 +235,11 @@ class TestCumOp(utt.InferShapeTester): ...@@ -235,10 +235,11 @@ class TestCumOp(utt.InferShapeTester):
def test_grad(self): def test_grad(self):
a = np.random.random((3, 5, 2)).astype(config.floatX) a = np.random.random((3, 5, 2)).astype(config.floatX)
utt.verify_grad(self.op_class(mode="add"), [a]) # Test axis=None # Test default axis=None using cumsum/cumprod functions
utt.verify_grad(self.op_class(mode="mul"), [a]) # Test axis=None utt.verify_grad(lambda x: cumsum(x), [a]) # Test axis=None for cumsum
utt.verify_grad(lambda x: cumprod(x), [a]) # Test axis=None for cumprod
for axis in range(-len(a.shape), len(a.shape)): for axis in range(len(a.shape)):
utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4) utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
utt.verify_grad(self.op_class(axis=axis, mode="mul"), [a], eps=4e-4) utt.verify_grad(self.op_class(axis=axis, mode="mul"), [a], eps=4e-4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论