提交 223ee154 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update Reshape C implementation

These changes remove the stride-based manual computation of the new shape, since those are potentially sensitive to broadcasted arrays with no strides.
上级 79961a62
......@@ -14,7 +14,7 @@ from aesara.tensor import _get_vector_length
from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor
from aesara.tensor.var import TensorConstant, TensorVariable
from aesara.tensor.var import TensorConstant
def register_shape_c_code(type, code, version=()):
......@@ -570,15 +570,11 @@ class Reshape(COp):
if len(shp) != self.ndim:
raise ValueError(
(
"shape argument to Reshape.perform has incorrect"
f" length {len(shp)}"
f", should be {self.ndim}"
"Shape argument to Reshape has incorrect"
f" length: {len(shp)}, should be {self.ndim}"
)
)
try:
out[0] = np.reshape(x, shp)
except Exception:
raise ValueError(f"Cannot reshape input of shape {x.shape} to shape {shp}")
def connection_pattern(self, node):
return [[True], [False]]
......@@ -669,44 +665,38 @@ class Reshape(COp):
]
def c_code_cache_version(self):
return (8,)
return (9,)
def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable):
x, shp = inputs
(z,) = outputs
sdtype = node.inputs[1].type.dtype_specs()[1]
fail = sub["fail"]
params = sub["params"]
return (
"""
assert (PyArray_NDIM(%(shp)s) == 1);
npy_intp new_dims[%(params)s->ndim];
return f"""
assert (PyArray_NDIM({shp}) == 1);
PyArray_Dims newshape;
newshape.ptr = new_dims;
newshape.len = %(params)s->ndim;
for (int ii = 0; ii < %(params)s->ndim; ++ii)
{
// -- We do not want an explicit cast here. the shp can be any
// -- int* dtype. The compiler will explicitly upcast it, but
// -- will err if this will downcast. This could happen if the
// -- user pass an int64 dtype, but npy_intp endup being int32.
new_dims[ii] = ((%(sdtype)s*)(
PyArray_BYTES(%(shp)s) +
ii * PyArray_STRIDES(%(shp)s)[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_CORDER);
if (!%(z)s)
{
if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{
{fail};
}}
if ({params}->ndim != newshape.len) {{
PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length");
PyDimMem_FREE(newshape.ptr);
{fail};
}}
Py_XDECREF({z});
{z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);
PyDimMem_FREE(newshape.ptr);
if (!{z}) {{
//The error message should have been set by PyArray_Newshape
%(fail)s;
}
{fail};
}}
"""
% locals()
)
else:
raise NotImplementedError()
def reshape(x, newshape, ndim=None):
......
......@@ -222,6 +222,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
f(a_val, [7, 5])
with pytest.raises(ValueError):
f(a_val, [-1, -1])
with pytest.raises(
ValueError, match=".*Shape argument to Reshape has incorrect length.*"
):
f(a_val, [3, 4, 1])
def test_0(self):
x = fvector("x")
......@@ -267,14 +271,14 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
[admat], [Reshape(ndim)(admat, [-1, 4])], [admat_val], Reshape
)
# enable when infer_shape is generalized:
# self._compile_and_check([admat, aivec],
# [Reshape(ndim)(admat, aivec)],
# [admat_val, [4, 3]], Reshape)
#
# self._compile_and_check([admat, aivec],
# [Reshape(ndim)(admat, aivec)],
# [admat_val, [4, -1]], Reshape)
aivec = ivector()
self._compile_and_check(
[admat, aivec], [Reshape(ndim)(admat, aivec)], [admat_val, [4, 3]], Reshape
)
self._compile_and_check(
[admat, aivec], [Reshape(ndim)(admat, aivec)], [admat_val, [4, -1]], Reshape
)
adtens4 = dtensor4()
ndim = 4
......@@ -287,14 +291,19 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
[adtens4], [Reshape(ndim)(adtens4, [1, 3, 10, 4])], [adtens4_val], Reshape
)
# enable when infer_shape is generalized:
# self._compile_and_check([adtens4, aivec],
# [Reshape(ndim)(adtens4, aivec)],
# [adtens4_val, [1, -1, 10, 4]], Reshape)
#
# self._compile_and_check([adtens4, aivec],
# [Reshape(ndim)(adtens4, aivec)],
# [adtens4_val, [1, 3, 10, 4]], Reshape)
self._compile_and_check(
[adtens4, aivec],
[Reshape(ndim)(adtens4, aivec)],
[adtens4_val, [1, -1, 10, 4]],
Reshape,
)
self._compile_and_check(
[adtens4, aivec],
[Reshape(ndim)(adtens4, aivec)],
[adtens4_val, [1, 3, 10, 4]],
Reshape,
)
def test_shape_i_hash():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论