提交 58a0320d 作者: Pascal Lamblin

C code for Alloc in all cases.

Also, use output memory even if not c-contiguous, instead of crashing.
上级 a05ada5d
@@ -2849,40 +2849,47 @@ class Alloc(gof.Op):
out[0][...] = v # broadcast v to fill us up out[0][...] = v # broadcast v to fill us up
def c_code(self, node, name, inp, out, sub):
    """Return the C source that allocates the output and fills it.

    The generated code reads the requested shape from the shape inputs,
    reuses the existing output array when its dimensions already match
    (allocating a fresh one otherwise), and then copies the value into
    the output, letting NumPy broadcast it.

    :param node: the Apply node being compiled (not used here).
    :param name: unique name for this node's code (not used here).
    :param inp: C variable names of the inputs; ``inp[0]`` is the value
        to fill with, ``inp[1:]`` are the output dimensions, one
        variable per dimension.
    :param out: one-element sequence holding the C variable name of the
        output array.
    :param sub: substitution dictionary; ``sub['fail']`` is the C
        snippet to run when an error occurs.
    :return: the C code as a string.
    """
    vv = inp[0]
    shape_vars = inp[1:]
    ndim = len(shape_vars)
    zz, = out
    fail = sub['fail']

    # A zero-length array is not valid ISO C/C++, so always declare at
    # least one element; it is simply unused for a 0-d output.
    code = """
        npy_intp shape[%(shape_len)s];
        """ % dict(shape_len=max(ndim, 1))

    # Initialize shape: each shape input holds a single integer, read
    # here as element 0 of its data buffer.
    for i, shp_i in enumerate(shape_vars):
        code += """
        shape[%(i)s] = ((dtype_%(shp_i)s*) %(shp_i)s->data)[0];
        """ % dict(i=i, shp_i=shp_i)

    code += """
        int need_new_out = (NULL == %(zz)s);
        // The || short-circuits, so dimensions[] is never read through
        // a NULL output pointer.
        for (int i = 0; i < %(ndim)s; i++)
            need_new_out = (need_new_out
                            || (%(zz)s->dimensions[i] != shape[i]));

        if (need_new_out)
        {
            Py_XDECREF(%(zz)s);
            %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
                shape, type_num_%(vv)s);
            if (!%(zz)s)
            {
                PyErr_SetString(PyExc_MemoryError, "alloc failed");
                %(fail)s
            }
        }

        // This function takes care of broadcasting.  It returns -1 with
        // a Python exception set when the copy cannot be performed.
        if (PyArray_CopyInto(%(zz)s, %(vv)s) == -1)
        {
            %(fail)s
        }
        """ % dict(vv=vv, ndim=ndim, zz=zz, fail=fail)

    return code

def c_code_cache_version(self):
    """Return the version tuple of the C code generated by ``c_code``."""
    # Bumped from (1,): the generated code now checks the result of
    # PyArray_CopyInto and never declares a zero-length shape array.
    return (2,)
def infer_shape(self, node, input_shapes):
    """Return the output shape as the node's shape inputs.

    Every input of ``node`` after the first is used as one dimension of
    the single output, so the result is a one-element list holding
    ``node.inputs[1:]``.  ``input_shapes`` is not consulted.
    """
    return [node.inputs[1:]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论