提交 001ddc38 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: don't waste time copying output before overwriting it

上级 8cdf8e62
...@@ -3537,6 +3537,16 @@ class BatchedDot(Op): ...@@ -3537,6 +3537,16 @@ class BatchedDot(Op):
_z, = out _z, = out
fail = sub["fail"] fail = sub["fail"]
# generate condition to detect noncontiguous arrays
def not_contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var
return " || ".join([
" || ".join("{strides}[{i}] < 1 || {strides}[{i}] % type_size"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " && ".join("{strides}[{i}] != type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
# generate code to allocate output based on runtime input shapes # generate code to allocate output based on runtime input shapes
...@@ -3550,8 +3560,9 @@ class BatchedDot(Op): ...@@ -3550,8 +3560,9 @@ class BatchedDot(Op):
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s" z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims)) % (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims) z_shape = ", ".join(z_dims)
z_not_contiguous = not_contiguous(_z, z_ndim)
allocate = """ allocate = """
if ((NULL == %(_z)s) || !(%(z_shape_correct)s)) if ((NULL == %(_z)s) || !(%(z_shape_correct)s) || (%(z_not_contiguous)s))
{ {
npy_intp dims[%(z_ndim)s] = {%(z_shape)s}; npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s); Py_XDECREF(%(_z)s);
...@@ -3565,18 +3576,12 @@ class BatchedDot(Op): ...@@ -3565,18 +3576,12 @@ class BatchedDot(Op):
} }
""" % locals() """ % locals()
# generate code to reallocate inputs/output contiguously if necessary # code to reallocate inputs contiguously if necessary
contiguate = [] contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim), (_z, z_ndim)]: for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
strides = "PyArray_STRIDES(%s)" % var _not_contiguous = not_contiguous(var, ndim)
not_contiguous = " || ".join([
" || ".join("{strides}[{i}] < 1 || {strides}[{i}] % type_size"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " && ".join("{strides}[{i}] != type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
contiguate.append(""" contiguate.append("""
if (%(not_contiguous)s) { if (%(_not_contiguous)s) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s); PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!_copy) if (!_copy)
%(fail)s %(fail)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论