提交 7dc7d0b8 authored 作者: Reyhane Askari's avatar Reyhane Askari

changed Join C code to be pure C

上级 03f569f9
...@@ -4048,28 +4048,37 @@ class Join(Op): ...@@ -4048,28 +4048,37 @@ class Join(Op):
out, = outputs out, = outputs
fail = sub['fail'] fail = sub['fail']
adtype = node.inputs[0].type.dtype_specs()[1] adtype = node.inputs[0].type.dtype_specs()[1]
copy_to_list = []
for i, inp in enumerate(tensors):
copy_to_list.append(
"""Py_INCREF(%s);
PyList_SetItem(list, %s, (PyObject*)%s);"""
% (inp, i, inp))
copy_inputs_to_list = '\n'.join(copy_to_list)
n = len(tensors)
khar = "printf(\"tensors_lens_sum: %d\", tensors_lens_sum);"
code = """ code = """
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0]; int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int tensors_lens_sum = 0""" % locals() PyObject* list = PyList_New(%(l)s);
for i, inp in enumerate(tensors): %(copy_inputs_to_list)s
code += """ + PyArray_DIM(%(inp)s, axis) """ % locals() int tensors_lens_sum;
code += """;\n if(%(view)s != -1) {
tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis); tensors_lens_sum = 0;
if(%(view)s != -1 && tensors_lens_sum == 0){ for(int i=0; i < %(n)s; i++){
tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
}
%(khar)s
tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis);
}
if(%(view)s != -1 && tensors_lens_sum == 0) {
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
Py_INCREF(%(non_empty_tensor)s); Py_INCREF(%(non_empty_tensor)s);
%(out)s = %(non_empty_tensor)s; %(out)s = %(non_empty_tensor)s;
} }else{
else{
PyObject* list = PyList_New(%(l)s);
""" % locals()
for i, inp in enumerate(tensors):
code += """
Py_INCREF(%(inp)s);
PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s);
""" % locals()
code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis) //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int ndim = PyArray_NDIM(%(input_1)s); int ndim = PyArray_NDIM(%(input_1)s);
if( axis < -ndim ){ if( axis < -ndim ){
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论