提交 aa40fd06 作者: John Salvatier

Added imports for the compiled version

上级 325a73a2
......@@ -14,6 +14,8 @@ def compile_cutils():
"""Do just the compilation of cutils_ext"""
code = """
#include <Python.h>
#include "numpy/arrayobject.h"
extern "C"{
static PyObject *
run_cthunk(PyObject *self, PyObject *args)
......@@ -35,10 +37,11 @@ def compile_cutils():
return Py_BuildValue("i", failure);
}
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{
static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{
int index = mit->size;
while (index--) {
((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0];
......@@ -46,21 +49,21 @@ def compile_cutils():
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
}
inplace_map_binop addition_funcs[] = {
npy_float64_inplace_add,
NULL};
inplace_map_binop addition_funcs[] = {
npy_float64_inplace_add,
NULL};
int type_numbers[] = {
NPY_FLOAT64,
-1000};
int type_numbers[] = {
NPY_FLOAT64,
-1000};
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
{
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
......@@ -82,9 +85,9 @@ def compile_cutils():
}
}
}
if ((it = (PyArrayIterObject *)\
PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd)) == NULL) {
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
return -1;
......@@ -95,12 +98,12 @@ def compile_cutils():
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
}
static PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
static PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
......@@ -156,19 +159,22 @@ def compile_cutils():
Py_INCREF(Py_None);
return Py_None;
fail:
fail:
Py_XDECREF(mit);
return NULL;
}
}
#endif
static PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."},
{"inplace_increment", inplace_increment, METH_VARARGS,
"."}
#if NPY_API_VERSION >= 0x00000008
{"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */
};"""
if PY3:
......@@ -201,6 +207,7 @@ def compile_cutils():
}
"""
import cmodule
loc = os.path.join(config.compiledir, 'cutils_ext')
if not os.path.exists(loc):
os.mkdir(loc)
......
......@@ -24,6 +24,8 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext
# We use these exceptions as well.
from theano.scalar import ComplexError, IntegerDivisionError
import theano.scalar.sharedvar
......@@ -7139,6 +7141,16 @@ class AdvancedIncSubtensor1(Op):
if self.set_instead_of_inc:
x[idx] = y
else:
try :
from cutils_ext.cutils_ext import inplace_increment as increment
except ImportError:
increment = self.inplace_increment1d_slow
increment(x,idx, y)
out[0] = x
def inplace_increment1d_slow(self, x, idx, y):
# If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`.
......@@ -7150,7 +7162,6 @@ class AdvancedIncSubtensor1(Op):
else:
for i in idx:
x[i] += y
out[0] = x
def infer_shape(self, node, ishapes):
x, y, ilist = ishapes
......@@ -7399,17 +7410,16 @@ class AdvancedIncSubtensor(Op):
out, = out_
if not self.inplace:
out[0] = inputs[0].copy()
else:
raise NotImplementedError('In place computation is not'
' implemented')
if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1]
else:
increment = None
try :
increment = gof.cutils_ext.inplace_increment
except:
raise NotImplementedError("Couldn't find
inplace_increment, update numpy.")
from cutils_ext.cutils_ext import inplace_increment as increment
except ImportError:
raise NotImplementedError('Did not find inplace_increment.'
'Update numpy?')
increment(out[0], tuple(inputs[2:]), inputs[1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论