提交 58a0320d，作者：Pascal Lamblin

C code for Alloc in all cases.

Also, use output memory even if not c-contiguous, instead of crashing.
上级 a05ada5d
...@@ -2849,40 +2849,47 @@ class Alloc(gof.Op): ...@@ -2849,40 +2849,47 @@ class Alloc(gof.Op):
out[0][...] = v # broadcast v to fill us up out[0][...] = v # broadcast v to fill us up
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
# TODO: use the elemwise code generator here vv = inp[0]
if python_all(node.inputs[0].broadcastable): ndim = len(inp[1:])
# filling with a scalar is a common use of alloc zz, = out
# that we can implement relatively easily fail = sub['fail']
vv = inp[0]
zz, = out code = """
fail = sub['fail'] npy_intp shape[%(ndim)s];
if node.outputs[0].ndim == 1: """ % dict(ndim=ndim)
N0 = inp[1]
return """ # Initialize shape
npy_intp N0 = ((dtype_%(N0)s*)%(N0)s->data)[0]; for i, shp_i in enumerate(inp[1:]):
dtype_%(vv)s vv; code += """
dtype_%(zz)s* zz; shape[%(i)s] = ((dtype_%(shp_i)s*) %(shp_i)s->data)[0];
if ((NULL == %(zz)s) || (%(zz)s->dimensions[0] != N0)) """ % dict(i=i, shp_i=shp_i)
{
if (%(zz)s) Py_XDECREF(%(zz)s); code += """
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1, int need_new_out = (NULL == %(zz)s);
&N0, type_num_%(vv)s); for (int i = 0; i < %(ndim)s; i++)
if(!%(zz)s) { need_new_out = (need_new_out
PyErr_SetString(PyExc_MemoryError, "alloc failed"); || (%(zz)s->dimensions[i] != shape[i]));
%(fail)s
} if (need_new_out)
} {
vv = ((dtype_%(vv)s*)%(vv)s->data)[0]; Py_XDECREF(%(zz)s);
zz = ((dtype_%(zz)s*)%(zz)s->data); %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
assert (%(zz)s->strides[0] == sizeof(dtype_%(zz)s)); shape, type_num_%(vv)s);
for (int i = 0; i < N0; ++i) if (!%(zz)s)
{ {
zz[i] = vv; PyErr_SetString(PyExc_MemoryError, "alloc failed");
%(fail)s
} }
""" % locals() }
// This function takes care of broadcasting
PyArray_CopyInto(%(zz)s, %(vv)s);
""" % dict(vv=vv, ndim=ndim, zz=zz, fail=fail)
# else pretend this never happened return code
return super(Alloc, self).c_code(node, name, inp, out, sub)
def c_code_cache_version(self):
return (1,)
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [node.inputs[1:]] return [node.inputs[1:]]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论