提交 ded89f6c authored 作者: James Bergstra's avatar James Bergstra

This commit fixes a bug in Dimshuffle.c_code. The bug is triggered when doing

a non-inplace DimShuffle of some view of underlying storage. The bug would result in the output strides being incorrect. The bug is fixed by doing the non-inplace copy part of the algorithm *before* calculating strides.
上级 96ba6945
...@@ -183,6 +183,8 @@ class DimShuffle(Op): ...@@ -183,6 +183,8 @@ class DimShuffle(Op):
storage[0] = numpy.asarray(res) #asarray puts scalars back into array storage[0] = numpy.asarray(res) #asarray puts scalars back into array
def c_code(self, node, name, (input,), (res,), sub): def c_code(self, node, name, (input,), (res,), sub):
basename = input + '__view_or_copy'
def statements(lst): def statements(lst):
return ';\n'.join(lst) + ';' return ';\n'.join(lst) + ';'
...@@ -194,47 +196,53 @@ class DimShuffle(Op): ...@@ -194,47 +196,53 @@ class DimShuffle(Op):
clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}'] clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']
#get the copy / view of the input depending on whether we're doing things inplace or not.
if self.inplace:
get_base = ['{ PyArrayObject * %(basename)s = %(input)s', 'Py_INCREF((PyObject*)%(basename)s)']
else:
get_base = [('{ PyArrayObject * %(basename)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
shape_statements = ['npy_intp dimensions[%i]'%nd_out] shape_statements = ['npy_intp dimensions[%i]'%nd_out]
shape_statements += [('dimensions['+str(i)+'] = %(input)s->dimensions['+str(o)+']') shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']')
if o != 'x' else if o != 'x' else
('dimensions['+str(i)+'] = 1') ('dimensions['+str(i)+'] = 1')
for i, o in enumerate(self.new_order)] for i, o in enumerate(self.new_order)]
strides_statements = ['npy_intp strides[%i]'%nd_out] strides_statements = ['npy_intp strides[%i]'%nd_out]
strides_statements += [('strides['+str(i)+'] = %(input)s->strides['+str(o)+']')
#set the strides of the non-broadcasted dimensions
strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']')
if o != 'x' else if o != 'x' else
('strides['+str(i)+'] = 0') ('strides['+str(i)+'] = 0')
for i, o in enumerate(self.new_order)] for i, o in enumerate(self.new_order)]
#set the strides of the broadcasted dimensions
strides_statements.append('if (strides['+str(nd_out)+'-1] == 0) strides['+str(nd_out)+'-1] = %(basename)s->descr->elsize')
for i in xrange(nd_out-2,-1, -1):
strides_statements.append("if (strides[%(i)s] == 0) strides[%(i)s] = strides[%(i)s+1] * dimensions[%(i)s+1]"%dict(i=str(i)))
if self.inplace:
get_output = ['{ PyArrayObject * base = %(input)s', 'Py_INCREF((PyObject*)base)']
else:
get_output = [('{ PyArrayObject * base = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
close_bracket = [ close_bracket = [
#create a new array, #create a new array,
('%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, ' ('%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, '
'' + str(nd_out) + ', dimensions, ' '' + str(nd_out) + ', dimensions, '
'PyArray_TYPE(base), strides, ' 'PyArray_TYPE(%(basename)s), strides, '
'base->data, base->descr->elsize, ' '%(basename)s->data, %(basename)s->descr->elsize, '
#borrow only the writable flag from the base #borrow only the writable flag from the base
# the NPY_OWNDATA flag will default to 0. # the NPY_OWNDATA flag will default to 0.
'PyArray_ISWRITEABLE(base), NULL)'), 'PyArray_ISWRITEABLE(%(basename)s), NULL)'),
#recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED #recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
'PyArray_UpdateFlags(%(res)s, NPY_UPDATE_ALL)', 'PyArray_UpdateFlags(%(res)s, NPY_UPDATE_ALL)',
#we are making a view in both inplace and non-inplace cases #we are making a view in both inplace and non-inplace cases
'%(res)s->base = (PyObject*)base', '%(res)s->base = (PyObject*)%(basename)s',
'}'] '}']
full_code = statements(check_input_nd full_code = statements(check_input_nd
+ clear_output + clear_output
+ get_base
+ shape_statements + shape_statements
+ strides_statements + strides_statements
+ get_output
+ close_bracket) + close_bracket)
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论