提交 aa40fd06 authored 作者: John Salvatier's avatar John Salvatier

added imports for compiled version

上级 325a73a2
...@@ -14,6 +14,8 @@ def compile_cutils(): ...@@ -14,6 +14,8 @@ def compile_cutils():
"""Do just the compilation of cutils_ext""" """Do just the compilation of cutils_ext"""
code = """ code = """
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h"
extern "C"{ extern "C"{
static PyObject * static PyObject *
run_cthunk(PyObject *self, PyObject *args) run_cthunk(PyObject *self, PyObject *args)
...@@ -35,10 +37,11 @@ def compile_cutils(): ...@@ -35,10 +37,11 @@ def compile_cutils():
return Py_BuildValue("i", failure); return Py_BuildValue("i", failure);
} }
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *); #if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it) static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{ {
int index = mit->size; int index = mit->size;
while (index--) { while (index--) {
((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0]; ((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0];
...@@ -46,21 +49,21 @@ def compile_cutils(): ...@@ -46,21 +49,21 @@ def compile_cutils():
PyArray_MapIterNext(mit); PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it); PyArray_ITER_NEXT(it);
} }
} }
inplace_map_binop addition_funcs[] = { inplace_map_binop addition_funcs[] = {
npy_float64_inplace_add, npy_float64_inplace_add,
NULL}; NULL};
int type_numbers[] = { int type_numbers[] = {
NPY_FLOAT64, NPY_FLOAT64,
-1000}; -1000};
static int static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace) map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
{ {
PyArrayObject *arr = NULL; PyArrayObject *arr = NULL;
PyArrayIterObject *it; PyArrayIterObject *it;
PyArray_Descr *descr; PyArray_Descr *descr;
...@@ -82,9 +85,9 @@ def compile_cutils(): ...@@ -82,9 +85,9 @@ def compile_cutils():
} }
} }
} }
it = (PyArrayIterObject*)
if ((it = (PyArrayIterObject *)\ PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd)) == NULL) { if (it == NULL) {
Py_DECREF(arr); Py_DECREF(arr);
return -1; return -1;
...@@ -95,12 +98,12 @@ def compile_cutils(): ...@@ -95,12 +98,12 @@ def compile_cutils():
Py_DECREF(arr); Py_DECREF(arr);
Py_DECREF(it); Py_DECREF(it);
return 0; return 0;
} }
static PyObject * static PyObject *
inplace_increment(PyObject *dummy, PyObject *args) inplace_increment(PyObject *dummy, PyObject *args)
{ {
PyObject *arg_a = NULL, *index=NULL, *inc=NULL; PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a; PyArrayObject *a;
inplace_map_binop add_inplace = NULL; inplace_map_binop add_inplace = NULL;
...@@ -156,19 +159,22 @@ def compile_cutils(): ...@@ -156,19 +159,22 @@ def compile_cutils():
Py_INCREF(Py_None); Py_INCREF(Py_None);
return Py_None; return Py_None;
fail: fail:
Py_XDECREF(mit); Py_XDECREF(mit);
return NULL; return NULL;
} }
#endif
static PyMethodDef CutilsExtMethods[] = { static PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS, {"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."}, "Run a theano cthunk."},
{"inplace_increment", inplace_increment, METH_VARARGS, #if NPY_API_VERSION >= 0x00000008
"."} {"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
};""" };"""
if PY3: if PY3:
...@@ -201,6 +207,7 @@ def compile_cutils(): ...@@ -201,6 +207,7 @@ def compile_cutils():
} }
""" """
import cmodule
loc = os.path.join(config.compiledir, 'cutils_ext') loc = os.path.join(config.compiledir, 'cutils_ext')
if not os.path.exists(loc): if not os.path.exists(loc):
os.mkdir(loc) os.mkdir(loc)
......
...@@ -24,6 +24,8 @@ from theano import compile, printing ...@@ -24,6 +24,8 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext
# We use these exceptions as well. # We use these exceptions as well.
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
import theano.scalar.sharedvar import theano.scalar.sharedvar
...@@ -7139,6 +7141,16 @@ class AdvancedIncSubtensor1(Op): ...@@ -7139,6 +7141,16 @@ class AdvancedIncSubtensor1(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
x[idx] = y x[idx] = y
else: else:
try :
from cutils_ext.cutils_ext import inplace_increment as increment
except ImportError:
increment = self.inplace_increment1d_slow
increment(x,idx, y)
out[0] = x
def inplace_increment1d_slow(self, x, idx, y):
# If `y` has as many dimensions as `x`, then we want to iterate # If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be # jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`. # broadcasted to fill all relevant rows of `x`.
...@@ -7150,7 +7162,6 @@ class AdvancedIncSubtensor1(Op): ...@@ -7150,7 +7162,6 @@ class AdvancedIncSubtensor1(Op):
else: else:
for i in idx: for i in idx:
x[i] += y x[i] += y
out[0] = x
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
x, y, ilist = ishapes x, y, ilist = ishapes
...@@ -7399,17 +7410,16 @@ class AdvancedIncSubtensor(Op): ...@@ -7399,17 +7410,16 @@ class AdvancedIncSubtensor(Op):
out, = out_ out, = out_
if not self.inplace: if not self.inplace:
out[0] = inputs[0].copy() out[0] = inputs[0].copy()
else:
raise NotImplementedError('In place computation is not'
' implemented')
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
else: else:
increment = None
try : try :
increment = gof.cutils_ext.inplace_increment from cutils_ext.cutils_ext import inplace_increment as increment
except: except ImportError:
raise NotImplementedError("Couldn't find raise NotImplementedError('Did not find inplace_increment.'
inplace_increment, update numpy.") 'Update numpy?')
increment(out[0], tuple(inputs[2:]), inputs[1]) increment(out[0], tuple(inputs[2:]), inputs[1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论