提交 897ac042 authored 作者: Frederic's avatar Frederic

Fix crash when check_input=True

上级 8b4f2b31
......@@ -782,22 +782,24 @@ class MakeVector(T.Op):
# So there will be (1 * nb_dtype) + ((nb len(inp) - 1 ))
# different c code with the following algo
out_shape = len(inp)
out_dtype = numpy.dtype(node.outputs[0].dtype).num
out_num = numpy.dtype(node.outputs[0].dtype).num
# don't use dtype_%(out)s as when check_input=False, it isn't defined.
out_dtype = node.outputs[0].type.dtype_specs()[1]
if len(inp) > 0:
assert self.dtype == node.inputs[0].dtype
out_dtype = 'PyArray_TYPE(%s)' % inp[0]
out_num = 'PyArray_TYPE(%s)' % inp[0]
ret = """
npy_intp dims[1];
dims[0] = %(out_shape)s;
if(!%(out)s || PyArray_DIMS(%(out)s)[0] != %(out_shape)s){
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_dtype)s, 0);
%(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_num)s, 0);
}
""" % locals()
for idx, i in enumerate(inp):
ret += """
*((dtype_%(out)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((dtype_%(out)s *) PyArray_DATA(%(i)s));
*((%(out_dtype)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((%(out_dtype)s *) PyArray_DATA(%(i)s));
""" % locals()
return ret
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论