提交 c3a45dd3 authored 作者: sebastien's avatar sebastien

c_code for AdvancedIncSubtensor1

上级 86b7ae60
......@@ -138,9 +138,10 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
}
static PyObject *
PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
fprintf(stderr, "prout1\\n");
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
......@@ -184,6 +185,7 @@ inplace_increment(PyObject *dummy, PyObject *args)
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
Py_INCREF(mit);
if (mit == NULL) {
goto fail;
}
......
from copy import copy
from itertools import izip
import os
import sys
from textwrap import dedent
import warnings
......@@ -1750,17 +1751,258 @@ class AdvancedIncSubtensor1(Op):
opname, x_.type.ndim, y_.type.ndim))
return Apply(self, [x_, y_, ilist_], [x_.type()])
def copy_of_x(self, x):
"""
:param x: a string giving the name of a C variable
pointing to an array
:return: C code expression to make a copy of x
Base class uses PyArrayObject *, subclasses may override for
different types of arrays.
"""
# Parameters of PyArrary_FromAny are:
# array
# dtype: we pass NULL to say any dtype is acceptable, so the existing
# dtype will be copied
# min_depth: we pass 0 to have this parameter ignored
# max_depth: we pass 0 to have this parameter ignored
# requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
# context: this is almost always NULL, I'm not sure what it's used for
return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def c_support_code(self):
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160', 'complex192', 'complex512']]
inplace_map_template = """
#if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
%(op)s
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
#endif
"""
floatadd = "((%(type)s*)mit->dataptr)[0] = inc_or_set * ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
complexadd = """
((%(type)s*)mit->dataptr)[0].real = inc_or_set * ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = inc_or_set * ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join(["""
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"""NULL};
""")
type_number_array = ("static int type_numbers[] = {" +
''.join(["""
#if defined(%(typen)s)
%(typen)s,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};")
return ("""
#include <Python.h>
#include "numpy/arrayobject.h"
extern "C" //{
PyObject *
run_cthunk(PyObject *self, PyObject *args)
{
PyObject *py_cthunk = NULL;
if(!PyArg_ParseTuple(args,"O",&py_cthunk))
return NULL;
if (!PyCObject_Check(py_cthunk)) {
PyErr_SetString(PyExc_ValueError,
"Argument to run_cthunk must be a PyCObject.");
return NULL;
}
void * ptr_addr = PyCObject_AsVoidPtr(py_cthunk);
int (*fn)(void*) = (int (*)(void*))(ptr_addr);
void* it = PyCObject_GetDesc(py_cthunk);
int failure = fn(it);
return Py_BuildValue("i", failure);
}
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int f);
""" + fns + fn_array + type_number_array +
"""
int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
if (mit->ait == NULL) {
return -1;
}
descr = PyArray_DESCR(mit->ait->ao);
Py_INCREF(descr);
arr = (PyArrayObject *)PyArray_FromAny(op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it, inc_or_set);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
int inc_or_set;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i = 0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOOi", &arg_a, &index,
&inc, &inc_or_set)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
}
a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL;
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL;
}
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
Py_INCREF(mit); // Should I INCREF here?
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
}
#endif
PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."},
#if NPY_API_VERSION >= 0x00000008
{"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */
};""")
def c_code(self, node, name, input_names, output_names, sub):
x, y, idx = input_names
out = output_names[0]
fail = sub['fail']
inc_or_set = 1 - self.set_instead_of_inc
if self.inplace: # convert bool to int
inplace = 1
else:
inplace = 0
copy_of_x = self.copy_of_x(x)
return """
PyObject *arglist = PyTuple_Pack(3,%(x)s, %(idx)s, %(y)s);
//PyObject *result /*Will be PyNone*/
//result = PyEval_CallObject(inplace_increment, arglist);
//Py_DECREF(arglist)
//Py_DECREF(result)
if (%(inplace)s)
{
if (%(x)s != %(out)s)
{
Py_XDECREF(%(out)s);
Py_INCREF(%(x)s);
%(out)s = %(x)s;
}
}
else
{
Py_XDECREF(%(out)s);
%(out)s = %(copy_of_x)s;
}
PyObject *arglist = Py_BuildValue("OOOi",%(out)s, %(idx)s, %(y)s, %(inc_or_set)d);
inplace_increment(NULL, arglist);
""" % locals()
def perform(self, node, inp, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论