提交 0f3f2644 authored 作者: sebastien's avatar sebastien

Avoid code duplication

上级 fac49f83
......@@ -13,8 +13,7 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
os.remove(os.path.join(config.compiledir, 'cutils_ext.so'))
def compile_cutils():
"""Do just the compilation of cutils_ext"""
def compile_cutils_code():
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
......@@ -26,7 +25,7 @@ def compile_cutils():
inplace_map_template = """
#if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
......@@ -39,10 +38,10 @@ def compile_cutils():
#endif
"""
floatadd = "((%(type)s*)mit->dataptr)[0] = ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
floatadd = "((%(type)s*)mit->dataptr)[0] = inc_or_set * ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
complexadd = """
((%(type)s*)mit->dataptr)[0].real = ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
((%(type)s*)mit->dataptr)[0].real = inc_or_set * ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = inc_or_set * ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
......@@ -72,37 +71,13 @@ def compile_cutils():
"-1000};")
code = ("""
#include <Python.h>
#include "numpy/arrayobject.h"
extern "C"{
static PyObject *
run_cthunk(PyObject *self, PyObject *args)
{
PyObject *py_cthunk = NULL;
if(!PyArg_ParseTuple(args,"O",&py_cthunk))
return NULL;
if (!PyCObject_Check(py_cthunk)) {
PyErr_SetString(PyExc_ValueError,
"Argument to run_cthunk must be a PyCObject.");
return NULL;
}
void * ptr_addr = PyCObject_AsVoidPtr(py_cthunk);
int (*fn)(void*) = (int (*)(void*))(ptr_addr);
void* it = PyCObject_GetDesc(py_cthunk);
int failure = fn(it);
return Py_BuildValue("i", failure);
}
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int inc_or_set);
""" + fns + fn_array + type_number_array +
"""
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
......@@ -130,7 +105,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
return -1;
}
(*add_inplace)(mit, it);
(*add_inplace)(mit, it, inc_or_set);
Py_DECREF(arr);
Py_DECREF(it);
......@@ -142,19 +117,20 @@ static PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
int inc_or_set = 1;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i =0;
int i = 0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index,
&inc)) {
if (!PyArg_ParseTuple(args, "OOO|i", &arg_a, &index,
&inc, &inc_or_set)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
}
a = (PyArrayObject *) arg_a;
......@@ -187,7 +163,7 @@ inplace_increment(PyObject *dummy, PyObject *args)
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace) != 0) {
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
goto fail;
}
......@@ -202,9 +178,40 @@ fail:
return NULL;
}
#endif
""")
return code
static PyMethodDef CutilsExtMethods[] = {
def compile_cutils():
"""Do just the compilation of cutils_ext"""
code = ("""
#include <Python.h>
#include "numpy/arrayobject.h"
extern "C"{
static PyObject *
run_cthunk(PyObject *self, PyObject *args)
{
PyObject *py_cthunk = NULL;
if(!PyArg_ParseTuple(args,"O",&py_cthunk))
return NULL;
if (!PyCObject_Check(py_cthunk)) {
PyErr_SetString(PyExc_ValueError,
"Argument to run_cthunk must be a PyCObject.");
return NULL;
}
void * ptr_addr = PyCObject_AsVoidPtr(py_cthunk);
int (*fn)(void*) = (int (*)(void*))(ptr_addr);
void* it = PyCObject_GetDesc(py_cthunk);
int failure = fn(it);
return Py_BuildValue("i", failure);
}""")
code += compile_cutils_code()
code += ("""static PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."},
#if NPY_API_VERSION >= 0x00000008
......@@ -214,7 +221,6 @@ fail:
#endif
{NULL, NULL, 0, NULL} /* Sentinel */
};""")
if PY3:
# This is not the most efficient code, but it is written this way to
# highlight the changes needed to make 2.x code compile under python 3.
......
......@@ -1774,170 +1774,8 @@ class AdvancedIncSubtensor1(Op):
NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def c_support_code(self):
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160', 'complex192', 'complex512']]
inplace_map_template = """
#if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
%(op)s
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
#endif
"""
floatadd = "((%(type)s*)mit->dataptr)[0] = inc_or_set * ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
complexadd = """
((%(type)s*)mit->dataptr)[0].real = inc_or_set * ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = inc_or_set * ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join(["""
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"""NULL};
""")
type_number_array = ("static int type_numbers[] = {" +
''.join(["""
#if defined(%(typen)s)
%(typen)s,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};")
return ("""
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int inc_or_set);
""" + fns + fn_array + type_number_array +
"""
int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
if (mit->ait == NULL) {
return -1;
}
descr = PyArray_DESCR(mit->ait->ao);
Py_INCREF(descr);
arr = (PyArrayObject *)PyArray_FromAny(op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it, inc_or_set);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
int inc_or_set;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i = 0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOOi", &arg_a, &index,
&inc, &inc_or_set)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
}
a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL;
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL;
}
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
}
#endif
""")
from theano.gof.cutils import compile_cutils_code
return compile_cutils_code()
def c_code(self, node, name, input_names, output_names, sub):
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
......@@ -1982,6 +1820,7 @@ class AdvancedIncSubtensor1(Op):
# In Numpy, x[idx] += y doesn't work if the same index is present
# many times: it does it only once. Is it a bug? In any case, for
# this reason we implement our own 'inc' iteration.
if self.set_instead_of_inc:
x[idx] = y
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论