提交 89d5366c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not introduce 0 strides for broadcastable dimensions in DimShuffle

Some poorly implemented BLAS operations don't handle them correctly
上级 bf628c97
......@@ -33,12 +33,17 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
npy_intp original_size = PyArray_SIZE(input);
npy_intp new_size = 1;
for (npy_intp i = 0; i < nd_out; ++i) {
// We set the strides of length 1 dimensions to PyArray_ITEMSIZE(input).
// The value is arbitrary, because there is never a next element.
// np.expand_dims(x, 0) and x[None] do different things here.
// I would prefer zero, but there are some poorly implemented BLAS operations
// That don't handle zero strides correctly. At least they won't fail because of DimShuffle.
if (new_order[i] != -1) {
dimensions[i] = PyArray_DIMS(input)[new_order[i]];
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]];
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? PyArray_ITEMSIZE(input) : PyArray_STRIDES(input)[new_order[i]];
} else {
dimensions[i] = 1;
strides[i] = 0;
strides[i] = PyArray_ITEMSIZE(input);
}
new_size *= dimensions[i];
}
......
......@@ -185,14 +185,13 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
# as the broadcasted value; that way, we'll be able to tell that we're getting
# junk data from a poorly constructed array view.
x_val = np.broadcast_to(2039, (5000,))
expected_x_val = x_val[None]
for i in range(1):
inputs[0].storage[0] = x_val
thunk()
# Make sure it's a view of the original data
assert np.shares_memory(x_val, outputs[0].storage[0])
# Confirm the right strides
assert outputs[0].storage[0].strides == expected_x_val.strides
assert outputs[0].storage[0].strides[-1] == 0
# Confirm the broadcasted value in the output
assert np.array_equiv(outputs[0].storage[0], 2039)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论