提交 49f76da9 作者: copilot-swe-agent[bot] 提交者: Ricardo Vieira

Implement axis=None raveling behavior symbolically in CumOp

上级 9b522a86
......@@ -41,21 +41,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
mode = op.mode
ndim = cast(TensorVariable, node.outputs[0]).ndim
if axis is not None:
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
reaxis_first_inv = tuple(np.argsort(reaxis_first))
if mode == "add":
if axis is None or ndim == 1:
if ndim == 1:
@numba_basic.numba_njit
def cumop(x):
return np.cumsum(x)
return np.cumsum(x, axis=axis)
else:
......@@ -75,11 +69,11 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
return res.transpose(reaxis_first_inv)
else:
if axis is None or ndim == 1:
if ndim == 1:
@numba_basic.numba_njit
def cumop(x):
return np.cumprod(x)
return np.cumprod(x, axis=axis)
else:
......@@ -96,7 +90,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
for m in range(1, x.shape[axis]):
res[m] = res[m - 1] * x_axis_first[m]
return res.transpose(reaxis_first)
return res.transpose(reaxis_first_inv)
return cumop
......
......@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
mode = op.mode
def cumop(x):
if axis is None:
x = x.reshape(-1)
dim = 0
else:
dim = axis
if mode == "add":
return torch.cumsum(x, dim=dim)
return torch.cumsum(x, dim=axis)
else:
return torch.cumprod(x, dim=dim)
return torch.cumprod(x, dim=axis)
return cumop
......
import warnings
from collections.abc import Collection, Iterable
from textwrap import dedent
import numpy as np
from numpy.lib.array_utils import normalize_axis_index
......@@ -44,10 +45,10 @@ from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
from pytensor.tensor.utils import normalize_reduce_axis
from pytensor.tensor.variable import TensorVariable
from pytensor.utils import LOCAL_BITWIDTH, NPY_RAVEL_AXIS, PYTHON_INT_BITWIDTH
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
class CpuContiguous(COp):
......@@ -290,33 +291,28 @@ class CumOp(COp):
__props__ = ("axis", "mode")
check_input = False
params_type = ParamsType(
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
)
def __init__(self, axis: int | None = None, mode="add"):
def __init__(self, axis: int, mode="add"):
if mode not in ("add", "mul"):
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
if not (isinstance(axis, int) or axis is None):
raise TypeError("axis must be an integer or None.")
if not isinstance(axis, int):
raise TypeError(f"axis must be an integer, got {axis} of type {type(axis)}")
if axis < 0:
raise ValueError(f"axis must be non-negative, got {axis}")
self.axis = axis
self.mode = mode
@property
def c_axis(self) -> int:
if self.axis is None:
return NPY_RAVEL_AXIS
return self.axis
def make_node(self, x):
x = ptb.as_tensor_variable(x)
out_type = x.type()
if self.axis is None:
out_type = vector(dtype=x.dtype) # Flatten
elif self.axis >= x.ndim or self.axis < -x.ndim:
raise ValueError(f"axis(={self.axis}) out of bounds")
if self.axis >= x.type.ndim:
raise ValueError(
f"axis(={self.axis}) out of bounds for variable {x} with {x.type.ndim} ndims"
)
return Apply(self, [x], [out_type])
return Apply(self, [x], [x.type()])
def perform(self, node, inputs, output_storage):
x = inputs[0]
......@@ -326,21 +322,10 @@ class CumOp(COp):
else:
z[0] = np.cumprod(x, axis=self.axis)
def grad(self, inputs, output_gradients):
def L_op(self, inputs, outputs, output_gradients):
(x,) = inputs
(gi,) = output_gradients
if self.axis is None:
if self.mode == "add":
return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
elif self.mode == "mul":
fx = cumprod(x, axis=self.axis)
return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
else:
raise NotImplementedError(
f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
)
reverse_slicing = [slice(None, None, None)] * gi.ndim
reverse_slicing[self.axis] = slice(None, None, -1)
reverse_slicing = tuple(reverse_slicing)
......@@ -357,9 +342,6 @@ class CumOp(COp):
)
def infer_shape(self, fgraph, node, shapes):
if self.axis is None and len(shapes[0]) > 1:
return [(prod(shapes[0]),)] # Flatten
return shapes
def c_code(self, node, name, inames, onames, sub):
......@@ -368,61 +350,43 @@ class CumOp(COp):
fail = sub["fail"]
params = sub["params"]
if self.axis is None:
axis_code = "int axis = NPY_RAVEL_AXIS;\n"
else:
axis_code = f"int axis = {params}->c_axis;\n"
code = (
axis_code
+ f"""
#undef NPY_UF_DBG_TRACING
#define NPY_UF_DBG_TRACING 1
if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_RAVEL_AXIS;
npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
{{
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
}}
return dedent(
f"""
int axis = {params}->axis;
else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
{{
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
if (!{z}){{ {fail} }};
}}
if (!{z})
{fail};
{{
PyObject * t = NULL;
if({params}->mode == MODE_ADD)
t = PyArray_CumSum(
{x}, axis,
PyArray_TYPE({x}), {z});
t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
else if({params}->mode == MODE_MUL)
t = PyArray_CumProd(
{x}, axis,
PyArray_TYPE({x}), {z});
t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
if (!t){{
{fail};
}}
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}}
"""
)
return code
def c_code_cache_version(self):
return (10,)
return (11,)
def __str__(self):
if self.mode == "add":
return f"Cumsum{{axis={self.axis}}}"
elif self.mode == "mul":
return f"Cumprod{{axis={self.axis}}}"
return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
......@@ -443,6 +407,12 @@ def cumsum(x, axis=None):
.. versionadded:: 0.7
"""
x = ptb.as_tensor_variable(x)
if axis is None:
x = x.ravel()
axis = 0
else:
axis = normalize_axis_index(axis, x.ndim)
return CumOp(axis=axis, mode="add")(x)
......@@ -463,6 +433,12 @@ def cumprod(x, axis=None):
.. versionadded:: 0.7
"""
x = ptb.as_tensor_variable(x)
if axis is None:
x = x.ravel()
axis = 0
else:
axis = normalize_axis_index(axis, x.ndim)
return CumOp(axis=axis, mode="mul")(x)
......@@ -471,18 +447,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
"""Vectorize the CumOp to work on a batch of inputs."""
[original_x] = node.inputs
batch_ndim = batch_x.ndim - original_x.ndim
axis = op.axis
if axis is None and original_x.ndim == 1:
axis = 0
elif axis is not None:
axis = normalize_axis_index(op.axis, original_x.ndim)
if axis is None:
# Ravel all unbatched dimensions and perform CumOp on the last axis
batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
else:
return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
# op.axis is already normalized and non-negative
return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
def diff(x, n=1, axis=-1):
......
......@@ -38,11 +38,6 @@ def test_Bartlett(val):
1,
"add",
),
(
(pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))),
-1,
"add",
),
(
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
0,
......@@ -53,11 +48,6 @@ def test_Bartlett(val):
1,
"add",
),
(
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
None,
"add",
),
(
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
0,
......@@ -68,11 +58,6 @@ def test_Bartlett(val):
1,
"mul",
),
(
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
None,
"mul",
),
],
)
def test_CumOp(val, axis, mode):
......
......@@ -5,39 +5,13 @@ import pytensor.tensor as pt
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize(
"dtype",
["float64", "int64"],
)
@pytest.mark.parametrize(
"axis",
[None, 1, (0,)],
)
@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, -1])
def test_pytorch_CumOp(axis, dtype):
"""Test PyTorch conversion of the `CumOp` `Op`."""
# Create a symbolic input for the first input of `CumOp`
a = pt.matrix("a", dtype=dtype)
# Create test value
test_value = np.arange(9, dtype=dtype).reshape((3, 3))
# Create the output variable
if isinstance(axis, tuple):
with pytest.raises(TypeError, match="axis must be an integer or None\\."):
out = pt.cumsum(a, axis=axis)
with pytest.raises(TypeError, match="axis must be an integer or None\\."):
out = pt.cumprod(a, axis=axis)
else:
out = pt.cumsum(a, axis=axis)
# Pass the inputs and outputs to the testing function
compare_pytorch_and_py([a], [out], [test_value])
# For the second mode of CumOp
out = pt.cumprod(a, axis=axis)
compare_pytorch_and_py([a], [out], [test_value])
outs = [pt.cumsum(a, axis=axis), pt.cumprod(a, axis=axis)]
compare_pytorch_and_py([a], outs, [test_value])
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
......
......@@ -195,7 +195,7 @@ class TestCumOp(utt.InferShapeTester):
def setup_method(self):
super().setup_method()
self.op_class = CumOp
self.op = CumOp()
self.op = CumOp(axis=0)
def test_cum_op(self):
x = tensor3("x")
......@@ -226,8 +226,8 @@ class TestCumOp(utt.InferShapeTester):
x = tensor3("x")
a = np.random.random((3, 5, 2)).astype(config.floatX)
# Test axis=None
self._compile_and_check([x], [self.op(x)], [a], self.op_class)
# Test default axis=None
self._compile_and_check([x], [cumsum(x)], [a], self.op_class)
for axis in range(-len(a.shape), len(a.shape)):
self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)
......@@ -235,10 +235,11 @@ class TestCumOp(utt.InferShapeTester):
def test_grad(self):
a = np.random.random((3, 5, 2)).astype(config.floatX)
utt.verify_grad(self.op_class(mode="add"), [a]) # Test axis=None
utt.verify_grad(self.op_class(mode="mul"), [a]) # Test axis=None
# Test default axis=None using cumsum/cumprod functions
utt.verify_grad(lambda x: cumsum(x), [a]) # Test axis=None for cumsum
utt.verify_grad(lambda x: cumprod(x), [a]) # Test axis=None for cumprod
for axis in range(-len(a.shape), len(a.shape)):
for axis in range(len(a.shape)):
utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
utt.verify_grad(self.op_class(axis=axis, mode="mul"), [a], eps=4e-4)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论