提交 551f2f0c authored 作者: Frederic's avatar Frederic

Add Join.c_code()

上级 9e84325c
...@@ -3538,6 +3538,36 @@ class Join(Op): ...@@ -3538,6 +3538,36 @@ class Join(Op):
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis), out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:]
l = len(tensors)
out, = outputs
fail = sub['fail']
code = """
PyObject* list = PyList_New(%(l)s);
""" % locals()
for i, inp in enumerate(tensors):
code += """
Py_INCREF(%(inp)s);
PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s);
""" % locals()
code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list,
((dtype_%(axis)s *)PyArray_DATA(%(axis)s))[0]);
Py_DECREF(list);
if(!%(out)s){
%(fail)s
}
""" % locals()
return code
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points[1:]: if None in eval_points[1:]:
return [None] return [None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论