Commit 223ee154, authored by Brandon T. Willard, committed by Brandon T. Willard

Update Reshape C implementation

These changes remove the stride-based manual computation of the new shape, since those are potentially sensitive to broadcasted arrays with no strides.
Parent: 79961a62
...@@ -14,7 +14,7 @@ from aesara.tensor import _get_vector_length ...@@ -14,7 +14,7 @@ from aesara.tensor import _get_vector_length
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor from aesara.tensor.type import TensorType, int_dtypes, tensor
from aesara.tensor.var import TensorConstant, TensorVariable from aesara.tensor.var import TensorConstant
def register_shape_c_code(type, code, version=()): def register_shape_c_code(type, code, version=()):
...@@ -570,15 +570,11 @@ class Reshape(COp): ...@@ -570,15 +570,11 @@ class Reshape(COp):
if len(shp) != self.ndim: if len(shp) != self.ndim:
raise ValueError( raise ValueError(
( (
"shape argument to Reshape.perform has incorrect" "Shape argument to Reshape has incorrect"
f" length {len(shp)}" f" length: {len(shp)}, should be {self.ndim}"
f", should be {self.ndim}"
) )
) )
try: out[0] = np.reshape(x, shp)
out[0] = np.reshape(x, shp)
except Exception:
raise ValueError(f"Cannot reshape input of shape {x.shape} to shape {shp}")
def connection_pattern(self, node):
    """Report gradient connectivity: the output depends on the data input
    (first entry) but not on the shape input (second entry)."""
    # One row per output; one column per input (x, shp).
    return [[True], [False]]
...@@ -669,44 +665,38 @@ class Reshape(COp): ...@@ -669,44 +665,38 @@ class Reshape(COp):
] ]
def c_code_cache_version(self):
    """Return the C-code cache version tuple.

    Bumped from (8,) to (9,) because the generated C implementation
    changed (stride-based shape extraction replaced by
    ``PyArray_IntpConverter``), so previously compiled modules must be
    invalidated.
    """
    return (9,)
def c_code(self, node, name, inputs, outputs, sub):
    """Generate the C implementation of ``Reshape``.

    Parameters
    ----------
    node : Apply
        The node being compiled (unused here; the shape is taken from the
        runtime input instead of compile-time metadata).
    name : str
        Unique name for this node in the generated module (unused).
    inputs : tuple of str
        C variable names ``(x, shp)`` for the data array and the 1-d
        shape array.
    outputs : tuple of str
        C variable name ``(z,)`` for the output array.
    sub : dict
        Provides the ``fail`` error-return snippet and the ``params``
        struct pointer (which carries ``ndim``).

    Notes
    -----
    The new shape is converted with ``PyArray_IntpConverter`` and applied
    via ``PyArray_Newshape``, instead of manually walking the shape
    array's strides — manual stride walks are unsafe for broadcasted
    arrays that may have zero strides.
    """
    x, shp = inputs
    (z,) = outputs
    fail = sub["fail"]
    params = sub["params"]
    return f"""
    assert (PyArray_NDIM({shp}) == 1);

    PyArray_Dims newshape;

    if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{
        {fail};
    }}

    if ({params}->ndim != newshape.len) {{
        PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length");
        PyDimMem_FREE(newshape.ptr);
        {fail};
    }}

    Py_XDECREF({z});
    {z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);

    PyDimMem_FREE(newshape.ptr);

    if (!{z}) {{
        //The error message should have been set by PyArray_Newshape
        {fail};
    }}
    """
def reshape(x, newshape, ndim=None): def reshape(x, newshape, ndim=None):
......
...@@ -222,6 +222,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -222,6 +222,10 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
f(a_val, [7, 5]) f(a_val, [7, 5])
with pytest.raises(ValueError): with pytest.raises(ValueError):
f(a_val, [-1, -1]) f(a_val, [-1, -1])
with pytest.raises(
ValueError, match=".*Shape argument to Reshape has incorrect length.*"
):
f(a_val, [3, 4, 1])
def test_0(self): def test_0(self):
x = fvector("x") x = fvector("x")
...@@ -267,14 +271,14 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -267,14 +271,14 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
[admat], [Reshape(ndim)(admat, [-1, 4])], [admat_val], Reshape [admat], [Reshape(ndim)(admat, [-1, 4])], [admat_val], Reshape
) )
# enable when infer_shape is generalized: aivec = ivector()
# self._compile_and_check([admat, aivec], self._compile_and_check(
# [Reshape(ndim)(admat, aivec)], [admat, aivec], [Reshape(ndim)(admat, aivec)], [admat_val, [4, 3]], Reshape
# [admat_val, [4, 3]], Reshape) )
#
# self._compile_and_check([admat, aivec], self._compile_and_check(
# [Reshape(ndim)(admat, aivec)], [admat, aivec], [Reshape(ndim)(admat, aivec)], [admat_val, [4, -1]], Reshape
# [admat_val, [4, -1]], Reshape) )
adtens4 = dtensor4() adtens4 = dtensor4()
ndim = 4 ndim = 4
...@@ -287,14 +291,19 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin): ...@@ -287,14 +291,19 @@ class TestReshape(utt.InferShapeTester, utt.OptimizationTestMixin):
[adtens4], [Reshape(ndim)(adtens4, [1, 3, 10, 4])], [adtens4_val], Reshape [adtens4], [Reshape(ndim)(adtens4, [1, 3, 10, 4])], [adtens4_val], Reshape
) )
# enable when infer_shape is generalized: self._compile_and_check(
# self._compile_and_check([adtens4, aivec], [adtens4, aivec],
# [Reshape(ndim)(adtens4, aivec)], [Reshape(ndim)(adtens4, aivec)],
# [adtens4_val, [1, -1, 10, 4]], Reshape) [adtens4_val, [1, -1, 10, 4]],
# Reshape,
# self._compile_and_check([adtens4, aivec], )
# [Reshape(ndim)(adtens4, aivec)],
# [adtens4_val, [1, 3, 10, 4]], Reshape) self._compile_and_check(
[adtens4, aivec],
[Reshape(ndim)(adtens4, aivec)],
[adtens4_val, [1, 3, 10, 4]],
Reshape,
)
def test_shape_i_hash(): def test_shape_i_hash():
......
Markdown formatting is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment.