提交 001ddc38 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: don't waste time copying output before overwriting it

上级 8cdf8e62
......@@ -3537,6 +3537,16 @@ class BatchedDot(Op):
_z, = out
fail = sub["fail"]
# generate condition to detect noncontiguous arrays
def not_contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var
return " || ".join([
" || ".join("{strides}[{i}] < 1 || {strides}[{i}] % type_size"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " && ".join("{strides}[{i}] != type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
# generate code to allocate output based on runtime input shapes
......@@ -3550,8 +3560,9 @@ class BatchedDot(Op):
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims)
z_not_contiguous = not_contiguous(_z, z_ndim)
allocate = """
if ((NULL == %(_z)s) || !(%(z_shape_correct)s))
if ((NULL == %(_z)s) || !(%(z_shape_correct)s) || (%(z_not_contiguous)s))
{
npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s);
......@@ -3565,18 +3576,12 @@ class BatchedDot(Op):
}
""" % locals()
# generate code to reallocate inputs/output contiguously if necessary
# code to reallocate inputs contiguously if necessary
contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim), (_z, z_ndim)]:
strides = "PyArray_STRIDES(%s)" % var
not_contiguous = " || ".join([
" || ".join("{strides}[{i}] < 1 || {strides}[{i}] % type_size"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " && ".join("{strides}[{i}] != type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_not_contiguous = not_contiguous(var, ndim)
contiguate.append("""
if (%(not_contiguous)s) {
if (%(_not_contiguous)s) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!_copy)
%(fail)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论