Commit 5dcf6048 authored by ricardoV94, committed by Ricardo Vieira

Revert regression in Reshape C-impl speed

This was caused by 223ee154, which used the generic `PyArray_IntpConverter` to convert the shape numpy vector into a simple C-array for the Reshape operation. There seems to be no need for this change, as the strides were already handled correctly. Profiling suggests the previous change caused a 7.5x slowdown; the benchmark detects only a 2.3x slowdown due to the PyTensor call overhead.
Parent 53763f5d
...@@ -16,7 +16,6 @@ from pytensor.graph.type import HasShape ...@@ -16,7 +16,6 @@ from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.scalar import int32
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.elemwise import get_normalized_batch_axes
...@@ -628,14 +627,11 @@ class Reshape(COp): ...@@ -628,14 +627,11 @@ class Reshape(COp):
check_input = False check_input = False
__props__ = ("ndim",) __props__ = ("ndim",)
params_type = ParamsType(ndim=int32)
# name does not participate because it doesn't affect computations
def __init__(self, ndim, name=None): def __init__(self, ndim):
self.ndim = int(ndim) self.ndim = int(ndim)
if ndim < 0: if ndim < 0:
raise ValueError("The output dimensions after reshape must be 0 or greater") raise ValueError("The output dimensions after reshape must be 0 or greater")
assert name is None, "name attribute for Reshape has been deprecated"
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}{{{self.ndim}}}" return f"{self.__class__.__name__}{{{self.ndim}}}"
...@@ -795,33 +791,32 @@ class Reshape(COp): ...@@ -795,33 +791,32 @@ class Reshape(COp):
] ]
def c_code_cache_version(self): def c_code_cache_version(self):
return (9,) return (10,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
x, shp = inputs x, shp = inputs
shp_dtype = node.inputs[1].type.dtype_specs()[1]
(z,) = outputs (z,) = outputs
fail = sub["fail"] fail = sub["fail"]
params = sub["params"] ndim = self.ndim
return f""" return f"""
assert (PyArray_NDIM({shp}) == 1); assert (PyArray_NDIM({shp}) == 1);
PyArray_Dims newshape; // Unpack shape into new_dims
npy_intp new_dims[{ndim}];
if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{ for (int ii = 0; ii < {ndim}; ++ii)
{fail}; {{
new_dims[ii] = (({shp_dtype}*)(PyArray_BYTES({shp}) + ii * PyArray_STRIDES({shp})[0]))[0];
}} }}
if ({params}->ndim != newshape.len) {{ PyArray_Dims newshape;
PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length"); newshape.len = {ndim};
PyDimMem_FREE(newshape.ptr); newshape.ptr = new_dims;
{fail};
}}
Py_XDECREF({z}); Py_XDECREF({z});
{z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER); {z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);
PyDimMem_FREE(newshape.ptr);
if (!{z}) {{ if (!{z}) {{
//The error message should have been set by PyArray_Newshape //The error message should have been set by PyArray_Newshape
{fail}; {fail};
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
from pytensor import Mode, function, grad from pytensor import In, Mode, Out, function, grad
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.basic import Variable, equal_computations
...@@ -12,7 +12,7 @@ from pytensor.graph.replace import clone_replace, vectorize_node ...@@ -12,7 +12,7 @@ from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant, stack from pytensor.tensor.basic import MakeVector, arange, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Reshape, Reshape,
...@@ -373,6 +373,43 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -373,6 +373,43 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
): ):
reshape(x2, (6, 3, 99)) reshape(x2, (6, 3, 99))
def test_shape_strides(self):
# Directly test the concern behind commit 223ee1548574b6bb8e73611ed605a97e29f13e7b
x = arange(8)
shape = vector("shape", dtype=int, shape=(3,))
fn = function([shape], x.reshape(shape))
# Empty strides
test_shape = np.broadcast_to(np.array(2), (3,))
assert test_shape.strides == (0,)
np.testing.assert_array_equal(
fn(test_shape),
np.arange(8).reshape(test_shape),
)
# Negative non-contiguous strides
test_shape = np.array([0, 4, 0, 2, 0, 1])[::-2]
assert np.all(test_shape == (1, 2, 4))
assert test_shape.strides == (-16,)
np.testing.assert_array_equal(
fn(test_shape),
np.arange(8).reshape(test_shape),
)
def test_benchmark(self, benchmark):
x = tensor3("x")
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
y1 = x.reshape((6, 4))
y2 = x.reshape((2, 12))
y3 = x.reshape((-1,))
# Borrow to avoid deepcopy overhead
reshape_fn = pytensor.function(
[In(x, borrow=True)],
[Out(y1, borrow=True), Out(y2, borrow=True), Out(y3, borrow=True)],
)
reshape_fn.trust_input = True
benchmark(reshape_fn, x_val)
def test_shape_i_hash(): def test_shape_i_hash():
assert isinstance(Shape_i(np.int64(1)).__hash__(), int) assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment