提交 98b979fe authored 作者: John Salvatier's avatar John Salvatier

moved inplace_increment into theano

上级 e0eca757
......@@ -35,9 +35,140 @@ def compile_cutils():
return Py_BuildValue("i", failure);
}
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
static void npy_float64_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it)
{
int index = mit->size;
while (index--) {
((npy_float64*)mit->dataptr)[0] = ((npy_float64*)mit->dataptr)[0] + ((npy_float64*)it->dataptr)[0];
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
inplace_map_binop addition_funcs[] = {
npy_float64_inplace_add,
NULL};
int type_numbers[] = {
NPY_FLOAT64,
-1000};
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
if (mit->ait == NULL) {
return -1;
}
descr = PyArray_DESCR(mit->ait->ao);
Py_INCREF(descr);
arr = (PyArrayObject *)PyArray_FromAny(op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
}
if ((it = (PyArrayIterObject *)\
PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd)) == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
static PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i =0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index,
&inc)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
}
a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL;
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL;
}
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
}
static PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."},
{"inplace_increment", inplace_increment, METH_VARARGS,
"."}
{NULL, NULL, 0, NULL} /* Sentinel */
};"""
if PY3:
......
......@@ -44,12 +44,9 @@ python_any = any
python_all = all
# Define common subsets of dtypes (as strings).
int_dtypes = map(str, scal.int_types)
uint_dtypes = map(str, scal.uint_types)
float_dtypes = map(str, scal.float_types)
complex_dtypes = map(str, scal.complex_types)
continuous_dtypes = map(str, scal.continuous_types)
float_dtypes = map(str, scal.float_types)
discrete_dtypes = map(str, scal.discrete_types)
all_dtypes = map(str, scal.all_types)
int_dtypes = map(str, scal.int_types)
......@@ -7292,7 +7289,7 @@ class AdvancedSubtensor(Op):
def make_node(self, x, *index):
x = as_tensor_variable(x)
index = tuple(map(as_index_variable, index))
......@@ -7300,7 +7297,7 @@ class AdvancedSubtensor(Op):
(x,) + index,
[tensor(dtype = x.type.dtype,
broadcastable = adv_broadcastable(x, index) )])
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......@@ -7308,6 +7305,18 @@ class AdvancedSubtensor(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs
def infer_shape(self, node, ishapes):
# Really special case
if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes
if len(xshp) == 2 and len(ind1shp) == 1 and len(ind2shp) == 1:
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
if node.inputs[2].owner is None:
return [ind2shp]
else:
return [ind1shp]
# Default case, we don't know
return node.fgraph.shape_feature.default_infer_shape(node, ishapes)
def perform(self, node, inputs, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论