提交 46c9f38a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove inplace_increment from cutils_ext.

上级 c60bd3b2
...@@ -14,189 +14,6 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')): ...@@ -14,189 +14,6 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
os.remove(os.path.join(config.compiledir, 'cutils_ext.so')) os.remove(os.path.join(config.compiledir, 'cutils_ext.so'))
def compile_cutils_code():
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32',
'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64',
'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160',
'complex192', 'complex512']]
inplace_map_template = """
#if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit,
PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
%(op)s
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
#endif
"""
floatadd = ("((%(type)s*)mit->dataptr)[0] = "
"(inc_or_set ? ((%(type)s*)mit->dataptr)[0] : 0)"
" + ((%(type)s*)it->dataptr)[0];")
complexadd = """
((%(type)s*)mit->dataptr)[0].real =
(inc_or_set ? ((%(type)s*)mit->dataptr)[0].real : 0)
+ ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag =
(inc_or_set ? ((%(type)s*)mit->dataptr)[0].imag : 0)
+ ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
def gen_binop(type, typen):
return """
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % dict(type=type, typen=typen)
fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join([gen_binop(type=t, typen=t.upper())
for t in types + complex_types]) + "NULL};\n")
def gen_num(typen):
return """
#if defined(%(typen)s)
%(typen)s,
#endif
""" % dict(type=type, typen=typen)
type_number_array = ("static int type_numbers[] = {" +
''.join([gen_num(typen=t.upper())
for t in types + complex_types]) + "-1000};")
code = ("""
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
PyArrayIterObject *, int inc_or_set);
""" + fns + fn_array + type_number_array + """
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op,
inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
if (mit->ait == NULL) {
return -1;
}
descr = PyArray_DESCR(mit->ait->ao);
Py_INCREF(descr);
arr = (PyArrayObject *)PyArray_FromAny(op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it, inc_or_set);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
static PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
int inc_or_set = 1;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i = 0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOO|i", &arg_a, &index,
&inc, &inc_or_set)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError,
"needs an ndarray as first argument");
return NULL;
}
a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL;
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL;
}
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
}
#endif
""")
return code
def compile_cutils(): def compile_cutils():
""" """
Do just the compilation of cutils_ext. Do just the compilation of cutils_ext.
...@@ -204,7 +21,6 @@ def compile_cutils(): ...@@ -204,7 +21,6 @@ def compile_cutils():
""" """
code = (""" code = ("""
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h"
#include "theano_mod_helper.h" #include "theano_mod_helper.h"
extern "C"{ extern "C"{
...@@ -226,18 +42,10 @@ def compile_cutils(): ...@@ -226,18 +42,10 @@ def compile_cutils():
int failure = fn(it); int failure = fn(it);
return Py_BuildValue("i", failure); return Py_BuildValue("i", failure);
}""") }
static PyMethodDef CutilsExtMethods[] = {
code += compile_cutils_code()
code += ("""static PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS, {"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."}, "Run a theano cthunk."},
#if NPY_API_VERSION >= 0x00000008
{"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
};""") };""")
if PY3: if PY3:
...@@ -256,7 +64,6 @@ def compile_cutils(): ...@@ -256,7 +64,6 @@ def compile_cutils():
PyMODINIT_FUNC PyMODINIT_FUNC
PyInit_cutils_ext(void) { PyInit_cutils_ext(void) {
import_array();
return PyModule_Create(&moduledef); return PyModule_Create(&moduledef);
} }
} }
...@@ -266,7 +73,6 @@ def compile_cutils(): ...@@ -266,7 +73,6 @@ def compile_cutils():
PyMODINIT_FUNC PyMODINIT_FUNC
initcutils_ext(void) initcutils_ext(void)
{ {
import_array();
(void) Py_InitModule("cutils_ext", CutilsExtMethods); (void) Py_InitModule("cutils_ext", CutilsExtMethods);
} }
} //extern C } //extern C
......
from __future__ import absolute_import, print_function, division
def inc_code():
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32',
'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64',
'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160',
'complex192', 'complex512']]
inplace_map_template = """
#if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit,
PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
%(op)s
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
#endif
"""
floatadd = ("((%(type)s*)mit->dataptr)[0] = "
"(inc_or_set ? ((%(type)s*)mit->dataptr)[0] : 0)"
" + ((%(type)s*)it->dataptr)[0];")
complexadd = """
((%(type)s*)mit->dataptr)[0].real =
(inc_or_set ? ((%(type)s*)mit->dataptr)[0].real : 0)
+ ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag =
(inc_or_set ? ((%(type)s*)mit->dataptr)[0].imag : 0)
+ ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
def gen_binop(type, typen):
return """
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % dict(type=type, typen=typen)
fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join([gen_binop(type=t, typen=t.upper())
for t in types + complex_types]) + "NULL};\n")
def gen_num(typen):
return """
#if defined(%(typen)s)
%(typen)s,
#endif
""" % dict(type=type, typen=typen)
type_number_array = ("static int type_numbers[] = {" +
''.join([gen_num(typen=t.upper())
for t in types + complex_types]) + "-1000};")
code = ("""
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
PyArrayIterObject *, int inc_or_set);
""" + fns + fn_array + type_number_array + """
static int
map_increment(PyArrayMapIterObject *mit, PyArrayObject *op,
inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
if (mit->ait == NULL) {
return -1;
}
descr = PyArray_DESCR(mit->ait->ao);
Py_INCREF(descr);
arr = (PyArrayObject *)PyArray_FromAny((PyObject *)op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it, inc_or_set);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
static int
inplace_increment(PyArrayObject *a, PyObject *index, PyArrayObject *inc,
int inc_or_set)
{
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i = 0;
PyArrayMapIterObject * mit;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return -1;
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return -1;
}
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return -1;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return 0;
fail:
Py_XDECREF(mit);
return -1;
}
#endif
""")
return code
...@@ -22,6 +22,8 @@ from theano.tensor.elemwise import DimShuffle ...@@ -22,6 +22,8 @@ from theano.tensor.elemwise import DimShuffle
from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice
from theano import config from theano import config
from .inc_code import inc_code
_logger = logging.getLogger("theano.tensor.subtensor") _logger = logging.getLogger("theano.tensor.subtensor")
# Do a lazy import of the sparse module # Do a lazy import of the sparse module
...@@ -1939,8 +1941,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -1939,8 +1941,7 @@ class AdvancedIncSubtensor1(Op):
NPY_ARRAY_ENSURECOPY, NULL)""" % locals() NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def c_support_code(self): def c_support_code(self):
from theano.gof.cutils import compile_cutils_code return inc_code()
return compile_cutils_code()
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
numpy_ver = [int(n) for n in np.__version__.split('.')[:2]] numpy_ver = [int(n) for n in np.__version__.split('.')[:2]]
...@@ -1972,17 +1973,14 @@ class AdvancedIncSubtensor1(Op): ...@@ -1972,17 +1973,14 @@ class AdvancedIncSubtensor1(Op):
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = %(copy_of_x)s; %(out)s = %(copy_of_x)s;
} }
PyObject *arglist = Py_BuildValue("OOOi",%(out)s, %(idx)s, %(y)s, %(inc_or_set)d); if (inplace_increment(%(out)s, (PyObject *)%(idx)s, %(y)s, %(inc_or_set)d)) {
rval = inplace_increment(NULL, arglist);
Py_XDECREF(arglist);
if (rval == NULL) {
%(fail)s; %(fail)s;
} }
Py_XDECREF(rval); Py_XDECREF(rval);
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
# TODO opt to make this inplace # TODO opt to make this inplace
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论