提交 ff5df65f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Revert regression in DimShuffle C-impl speed

The regression was introduced in e593b0ac, which was made in response to a bug with inputs that have zero strides. That bug can be fixed simply by removing a block that assumed some `full`/`broadcasting` behavior on the part of the operation, which does not happen with DimShuffle.
上级 5dcf6048
#section support_code_apply
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
PARAMS_TYPE *params) {
// This points to either the original input or a copy we create below.
// Either way, this is what we should be working on/with.
PyArrayObject *_input;
if (*res)
Py_XDECREF(*res);
if (params->inplace) {
_input = input;
Py_INCREF((PyObject *)_input);
} else {
_input = (PyArrayObject *)PyArray_FromAny(
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
NULL);
}
PyArray_Dims permute;
if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) {
return 1;
}
/*
res = res.transpose(self.transposition)
*/
PyArrayObject *transposed_input =
(PyArrayObject *)PyArray_Transpose(_input, &permute);
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PARAMS_TYPE *params) {
npy_int64* new_order;
npy_intp nd_in;
npy_intp nd_out;
npy_intp* dimensions;
npy_intp* strides;
// This points to either the original input or a copy we create below.
// Either way, this is what we should be working on/with.
PyArrayObject *_input;
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
return 1;
}
new_order = (npy_int64*) PyArray_DATA(params->_new_order);
nd_in = (npy_intp)(params->input_ndim);
nd_out = PyArray_SIZE(params->_new_order);
Py_DECREF(_input);
if (PyArray_NDIM(input) != nd_in) {
PyErr_SetString(PyExc_NotImplementedError, "DimShuffle: Input has less dimensions than expected.");
return 1;
}
PyDimMem_FREE(permute.ptr);
// Compute new dimensions and strides
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
if (dimensions == NULL || strides == NULL) {
PyErr_NoMemory();
free(dimensions);
free(strides);
return 1;
};
npy_intp original_size = PyArray_SIZE(_input);
npy_intp new_size = 1;
for (npy_intp i = 0; i < nd_out; ++i) {
if (new_order[i] != -1) {
dimensions[i] = PyArray_DIMS(_input)[new_order[i]];
strides[i] = PyArray_DIMS(_input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(_input)[new_order[i]];
} else {
dimensions[i] = 1;
strides[i] = 0;
}
new_size *= dimensions[i];
}
npy_intp *res_shape = PyArray_DIMS(transposed_input);
npy_intp N_shuffle = PyArray_SIZE(params->shuffle);
npy_intp N_augment = PyArray_SIZE(params->augment);
npy_intp N = N_augment + N_shuffle;
npy_intp *_reshape_shape = PyDimMem_NEW(N);
if (original_size != new_size) {
PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one.");
free(dimensions);
free(strides);
return 1;
}
if (_reshape_shape == NULL) {
PyErr_NoMemory();
return 1;
}
if (*res)
Py_XDECREF(*res);
/*
shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
*/
npy_intp aug_idx = 0;
int res_idx = 0;
for (npy_intp i = 0; i < N; i++) {
if (aug_idx < N_augment &&
i == *((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) {
_reshape_shape[i] = 1;
aug_idx++;
if (params->inplace) {
_input = input;
Py_INCREF((PyObject*)_input);
} else {
_reshape_shape[i] = res_shape[res_idx];
res_idx++;
_input = (PyArrayObject *)PyArray_FromAny(
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
NULL);
}
}
PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N};
/* res = res.reshape(shape) */
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
NPY_CORDER);
Py_DECREF(transposed_input);
// Create the new array.
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
PyArray_TYPE(_input), strides,
PyArray_DATA(_input), PyArray_ITEMSIZE(_input),
// borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0.
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(_input)),
NULL);
if (*res == NULL) {
free(dimensions);
free(strides);
return 1;
}
PyDimMem_FREE(reshape_shape.ptr);
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
if (!*res) {
return 1;
}
// we are making a view in both inplace and non-inplace cases
PyArray_SetBaseObject(*res, (PyObject*)_input);
return 0;
}
free(strides);
free(dimensions);
return 0;
}
\ No newline at end of file
......@@ -21,7 +21,7 @@ from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast
from pytensor.scalar.basic import int64, transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
......@@ -121,10 +121,9 @@ class DimShuffle(ExternalCOp):
@property
def params_type(self):
return ParamsType(
shuffle=lvector,
augment=lvector,
transposition=lvector,
_new_order=lvector,
inplace=scalar_bool,
input_ndim=int64,
)
def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
......@@ -135,6 +134,7 @@ class DimShuffle(ExternalCOp):
self.input_ndim = input_ndim
self.new_order = tuple(new_order)
self._new_order = [(-1 if x == "x" else x) for x in self.new_order]
self.inplace = True
for i, j in enumerate(new_order):
......@@ -231,10 +231,15 @@ class DimShuffle(ExternalCOp):
def perform(self, node, inp, out):
(res,) = inp
(storage,) = out
if not isinstance(res, np.ndarray | np.memmap):
raise TypeError(res)
# This C-like impl is very slow in Python compared to transpose+reshape
# new_order = self._new_order
# old_shape = inp.shape
# old_strides = inp.strides
# res = as_strided(
# shape = [1 if i == -1 else old_shape[i] for i in new_order],
# strides=[0 if i == -1 else old_strides[i] for i in new_order],
# )
# Put dropped axis at end
res = res.transpose(self.transposition)
......@@ -248,7 +253,7 @@ class DimShuffle(ExternalCOp):
if not self.inplace:
res = np.copy(res)
storage[0] = np.asarray(res)
out[0][0] = res
def infer_shape(self, fgraph, node, shapes):
(ishp,) = shapes
......
import itertools
import math
import re
import tracemalloc
......@@ -10,6 +11,7 @@ import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor import In, Out
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
......@@ -35,6 +37,7 @@ from pytensor.tensor.type import (
matrix,
scalar,
tensor,
tensor3,
vector,
vectors,
)
......@@ -158,11 +161,14 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
# as the broadcasted value; that way, we'll be able to tell that we're getting
# junk data from a poorly constructed array view.
x_val = np.broadcast_to(2039, (5000,))
for i in range(1000):
expected_x_val = x_val[None]
for i in range(1):
inputs[0].storage[0] = x_val
thunk()
# Make sure it's a view of the original data
assert np.shares_memory(x_val, outputs[0].storage[0])
# Confirm the right strides
assert outputs[0].storage[0].strides == expected_x_val.strides
# Confirm the broadcasted value in the output
assert np.array_equiv(outputs[0].storage[0], 2039)
......@@ -212,6 +218,24 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
with pytest.raises(TypeError, match="input_ndim must be an integer"):
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
def test_benchmark(self, benchmark):
    """Benchmark one compiled function that returns many reshuffled views.

    The outputs are every axis permutation of a 3D input, followed by the
    input indexed with ``None`` at each of the four possible positions.
    """
    x = tensor3("x")
    x_val = np.random.random((2, 3, 4)).astype(config.floatX)

    permuted = [x.transpose(axes) for axes in itertools.permutations((0, 1, 2))]
    expanded = [x[None], x[:, None], x[:, :, None], x[:, :, :, None]]
    outputs = [Out(y, borrow=True) for y in permuted + expanded]

    # Borrow on both the input and the outputs to avoid deepcopy overhead.
    fn = pytensor.function([In(x, borrow=True)], outputs)
    fn.trust_input = True

    benchmark(fn, x_val)
class TestBroadcast:
# this is to allow other types to reuse this class to test their ops
......
......@@ -480,10 +480,7 @@ class TestSqueeze(utt.InferShapeTester):
assert f([0]) == 0
# Test that we cannot squeeze dimensions whose length is greater than 1
with pytest.raises(
ValueError,
match="cannot reshape array of size 3 into shape ()",
):
with pytest.raises(ValueError):
f([0, 1, 2])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论