提交 37848106 authored 作者: John Salvatier's avatar John Salvatier

support lots of types

上级 de444dd9
...@@ -12,7 +12,58 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')): ...@@ -12,7 +12,58 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def compile_cutils(): def compile_cutils():
"""Do just the compilation of cutils_ext""" """Do just the compilation of cutils_ext"""
code = """
types = ['npy_'+t for t in ['int8', 'int16', 'int32', 'int64', 'int128', 'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'float16', 'float32', 'float64', 'float80', 'float96', 'float128', 'float256'] ]
complex_types = ['npy_'+t for t in ['complex32', 'complex64', 'complex128', 'complex160', 'complex192', 'complex512'] ]
inplace_map_template = """
#if defined({typen})
static void {type}_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{{
int index = mit->size;
while (index--) {{
{op}
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}}
}}
#endif
"""
floatadd = "(({type}*)mit->dataptr)[0] = (({type}*)mit->dataptr)[0] + (({type}*)it->dataptr)[0];"
complexadd = """
(({type}*)mit->dataptr)[0].real = (({type}*)mit->dataptr)[0].real + (({type}*)it->dataptr)[0].real;
(({type}*)mit->dataptr)[0].imag = (({type}*)mit->dataptr)[0].imag + (({type}*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template.format(type = t, typen = t.upper(), op = floatadd.format(type = t)) for t in types] +
[inplace_map_template.format(type = t, typen = t.upper(), op = complexadd.format(type = t)) for t in complex_types])
fn_array = ("inplace_map_binop addition_funcs[] = {" +
''.join(["""
#if defined({typen})
{type}_inplace_add,
#endif
""".format(type = t, typen = t.upper()) for t in types+complex_types]) +
"""NULL};
""")
type_number_array = ("int type_numbers[] = {" +
''.join(["""
#if defined({typen})
{typen},
#endif
""".format(type = t, typen = t.upper()) for t in types+complex_types]) +
"-1000};")
code = ("""
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h" #include "numpy/arrayobject.h"
...@@ -38,29 +89,10 @@ def compile_cutils(): ...@@ -38,29 +89,10 @@ def compile_cutils():
} }
#if NPY_API_VERSION >= 0x00000008 #if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *); typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
""" + fns + fn_array + type_number_array +
static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{
int index = mit->size;
while (index--) {
((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0];
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
inplace_map_binop addition_funcs[] = {
npy_float64_inplace_add,
NULL};
int type_numbers[] = {
NPY_FLOAT64,
-1000};
"""
static int static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace) map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
{ {
...@@ -206,7 +238,7 @@ fail: ...@@ -206,7 +238,7 @@ fail:
(void) Py_InitModule("cutils_ext", CutilsExtMethods); (void) Py_InitModule("cutils_ext", CutilsExtMethods);
} }
} //extern C } //extern C
""" """)
import cmodule import cmodule
loc = os.path.join(config.compiledir, 'cutils_ext') loc = os.path.join(config.compiledir, 'cutils_ext')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论