提交 aa40fd06 authored 作者: John Salvatier's avatar John Salvatier

added imports for compiled version

上级 325a73a2
...@@ -14,6 +14,8 @@ def compile_cutils(): ...@@ -14,6 +14,8 @@ def compile_cutils():
"""Do just the compilation of cutils_ext""" """Do just the compilation of cutils_ext"""
code = """ code = """
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h"
extern "C"{ extern "C"{
static PyObject * static PyObject *
run_cthunk(PyObject *self, PyObject *args) run_cthunk(PyObject *self, PyObject *args)
...@@ -35,140 +37,144 @@ def compile_cutils(): ...@@ -35,140 +37,144 @@ def compile_cutils():
return Py_BuildValue("i", failure); return Py_BuildValue("i", failure);
} }
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *); #if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{ static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
int index = mit->size; {
while (index--) { int index = mit->size;
((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0]; while (index--) {
((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0];
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it); PyArray_MapIterNext(mit);
} PyArray_ITER_NEXT(it);
} }
}
inplace_map_binop addition_funcs[] = {
npy_float64_inplace_add, inplace_map_binop addition_funcs[] = {
NULL}; npy_float64_inplace_add,
NULL};
int type_numbers[] = {
NPY_FLOAT64, int type_numbers[] = {
-1000}; NPY_FLOAT64,
-1000};
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace) static int
{ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
PyArrayObject *arr = NULL; {
PyArrayIterObject *it; PyArrayObject *arr = NULL;
PyArray_Descr *descr; PyArrayIterObject *it;
if (mit->ait == NULL) { PyArray_Descr *descr;
return -1; if (mit->ait == NULL) {
} return -1;
descr = PyArray_DESCR(mit->ait->ao); }
Py_INCREF(descr); descr = PyArray_DESCR(mit->ait->ao);
arr = (PyArrayObject *)PyArray_FromAny(op, descr, Py_INCREF(descr);
0, 0, NPY_ARRAY_FORCECAST, NULL); arr = (PyArrayObject *)PyArray_FromAny(op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) { if (arr == NULL) {
return -1; return -1;
} }
if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
}
if ((it = (PyArrayIterObject *)\
PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd)) == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
} }
}
it = (PyArrayIterObject*)
static PyObject * PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
inplace_increment(PyObject *dummy, PyObject *args) if (it == NULL) {
{ Py_DECREF(arr);
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a; return -1;
inplace_map_binop add_inplace = NULL; }
int type_number = -1;
int i =0; (*add_inplace)(mit, it);
PyArrayMapIterObject * mit;
Py_DECREF(arr);
if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index, Py_DECREF(it);
&inc)) { return 0;
return NULL; }
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument"); static PyObject *
return NULL; inplace_increment(PyObject *dummy, PyObject *args)
} {
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
a = (PyArrayObject *) arg_a; PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) { int type_number = -1;
return NULL; int i =0;
} PyArrayMapIterObject * mit;
if (PyArray_NDIM(a) == 0) { if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index,
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed."); &inc)) {
return NULL; return NULL;
} }
type_number = PyArray_TYPE(a); if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
}
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) { a = (PyArrayObject *) arg_a;
add_inplace = addition_funcs[i];
break; if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
} return NULL;
i++ ; }
}
if (PyArray_NDIM(a) == 0) {
if (add_inplace == NULL) { PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
PyErr_SetString(PyExc_TypeError, "unsupported type for a"); return NULL;
return NULL; }
} type_number = PyArray_TYPE(a);
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
if (mit == NULL) {
goto fail;
} while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (map_increment(mit, inc, add_inplace) != 0) { if (type_number == type_numbers[i]) {
goto fail; add_inplace = addition_funcs[i];
} break;
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
} }
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
}
#endif
static PyMethodDef CutilsExtMethods[] = { static PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS, {"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."}, "Run a theano cthunk."},
{"inplace_increment", inplace_increment, METH_VARARGS, #if NPY_API_VERSION >= 0x00000008
"."} {"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
};""" };"""
if PY3: if PY3:
...@@ -198,9 +204,10 @@ def compile_cutils(): ...@@ -198,9 +204,10 @@ def compile_cutils():
{ {
(void) Py_InitModule("cutils_ext", CutilsExtMethods); (void) Py_InitModule("cutils_ext", CutilsExtMethods);
} }
} }
""" """
import cmodule
loc = os.path.join(config.compiledir, 'cutils_ext') loc = os.path.join(config.compiledir, 'cutils_ext')
if not os.path.exists(loc): if not os.path.exists(loc):
os.mkdir(loc) os.mkdir(loc)
......
...@@ -24,6 +24,8 @@ from theano import compile, printing ...@@ -24,6 +24,8 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext
# We use these exceptions as well. # We use these exceptions as well.
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
import theano.scalar.sharedvar import theano.scalar.sharedvar
...@@ -7139,19 +7141,28 @@ class AdvancedIncSubtensor1(Op): ...@@ -7139,19 +7141,28 @@ class AdvancedIncSubtensor1(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
x[idx] = y x[idx] = y
else: else:
# If `y` has as many dimensions as `x`, then we want to iterate try :
# jointly on `x` and `y`. Otherwise, it means `y` should be from cutils_ext.cutils_ext import inplace_increment as increment
# broadcasted to fill all relevant rows of `x`. except ImportError:
assert y.ndim <= x.ndim # Should be guaranteed by `make_node` increment = self.inplace_increment1d_slow
if y.ndim == x.ndim:
assert len(y) == len(idx) increment(x,idx, y)
for (j, i) in enumerate(idx):
x[i] += y[j]
else:
for i in idx:
x[i] += y
out[0] = x out[0] = x
def inplace_increment1d_slow(self, x, idx, y):
# If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim:
assert len(y) == len(idx)
for (j, i) in enumerate(idx):
x[i] += y[j]
else:
for i in idx:
x[i] += y
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
x, y, ilist = ishapes x, y, ilist = ishapes
return [x] return [x]
...@@ -7399,17 +7410,16 @@ class AdvancedIncSubtensor(Op): ...@@ -7399,17 +7410,16 @@ class AdvancedIncSubtensor(Op):
out, = out_ out, = out_
if not self.inplace: if not self.inplace:
out[0] = inputs[0].copy() out[0] = inputs[0].copy()
else:
raise NotImplementedError('In place computation is not'
' implemented')
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
else: else:
try : increment = None
increment = gof.cutils_ext.inplace_increment try :
except: from cutils_ext.cutils_ext import inplace_increment as increment
raise NotImplementedError("Couldn't find except ImportError:
inplace_increment, update numpy.") raise NotImplementedError('Did not find inplace_increment.'
'Update numpy?')
increment(out[0], tuple(inputs[2:]), inputs[1]) increment(out[0], tuple(inputs[2:]), inputs[1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论