提交 1075d838 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: negate contiguity test for less confusion

上级 001ddc38
...@@ -3537,13 +3537,13 @@ class BatchedDot(Op): ...@@ -3537,13 +3537,13 @@ class BatchedDot(Op):
_z, = out _z, = out
fail = sub["fail"] fail = sub["fail"]
# generate condition to detect noncontiguous arrays # generate contiguity condition
def not_contiguous(var, ndim): def contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var strides = "PyArray_STRIDES(%s)" % var
return " || ".join([ return " && ".join([
" || ".join("{strides}[{i}] < 1 || {strides}[{i}] % type_size" " && ".join("{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
.format(strides=strides, i=i) for i in range(ndim)), .format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " && ".join("{strides}[{i}] != type_size" "(%s)" % " || ".join("{strides}[{i}] == type_size"
.format(strides=strides, i=i) for i in range(ndim)), .format(strides=strides, i=i) for i in range(ndim)),
]) ])
...@@ -3560,9 +3560,9 @@ class BatchedDot(Op): ...@@ -3560,9 +3560,9 @@ class BatchedDot(Op):
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s" z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims)) % (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims) z_shape = ", ".join(z_dims)
z_not_contiguous = not_contiguous(_z, z_ndim) z_contiguous = contiguous(_z, z_ndim)
allocate = """ allocate = """
if ((NULL == %(_z)s) || !(%(z_shape_correct)s) || (%(z_not_contiguous)s)) if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{ {
npy_intp dims[%(z_ndim)s] = {%(z_shape)s}; npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s); Py_XDECREF(%(_z)s);
...@@ -3579,9 +3579,9 @@ class BatchedDot(Op): ...@@ -3579,9 +3579,9 @@ class BatchedDot(Op):
# code to reallocate inputs contiguously if necessary # code to reallocate inputs contiguously if necessary
contiguate = [] contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]: for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_not_contiguous = not_contiguous(var, ndim) _contiguous = contiguous(var, ndim)
contiguate.append(""" contiguate.append("""
if (%(_not_contiguous)s) { if (!(%(_contiguous)s)) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s); PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!_copy) if (!_copy)
%(fail)s %(fail)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论