提交 c3a45dd3 authored 作者: sebastien's avatar sebastien

c_code for AdvancedIncSubtensor1

上级 86b7ae60
...@@ -138,9 +138,10 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp ...@@ -138,9 +138,10 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
} }
static PyObject * PyObject *
inplace_increment(PyObject *dummy, PyObject *args) inplace_increment(PyObject *dummy, PyObject *args)
{ {
fprintf(stderr, "prout1\\n");
PyObject *arg_a = NULL, *index=NULL, *inc=NULL; PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a; PyArrayObject *a;
inplace_map_binop add_inplace = NULL; inplace_map_binop add_inplace = NULL;
...@@ -184,6 +185,7 @@ inplace_increment(PyObject *dummy, PyObject *args) ...@@ -184,6 +185,7 @@ inplace_increment(PyObject *dummy, PyObject *args)
return NULL; return NULL;
} }
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index); mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
Py_INCREF(mit);
if (mit == NULL) { if (mit == NULL) {
goto fail; goto fail;
} }
......
from copy import copy from copy import copy
from itertools import izip from itertools import izip
import os
import sys import sys
from textwrap import dedent from textwrap import dedent
import warnings import warnings
...@@ -1751,16 +1752,257 @@ class AdvancedIncSubtensor1(Op): ...@@ -1751,16 +1752,257 @@ class AdvancedIncSubtensor1(Op):
return Apply(self, [x_, y_, ilist_], [x_.type()]) return Apply(self, [x_, y_, ilist_], [x_.type()])
def copy_of_x(self, x):
"""
:param x: a string giving the name of a C variable
pointing to an array
:return: C code expression to make a copy of x
Base class uses PyArrayObject *, subclasses may override for
different types of arrays.
"""
# Parameters of PyArrary_FromAny are:
# array
# dtype: we pass NULL to say any dtype is acceptable, so the existing
# dtype will be copied
# min_depth: we pass 0 to have this parameter ignored
# max_depth: we pass 0 to have this parameter ignored
# requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
# context: this is almost always NULL, I'm not sure what it's used for
return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def c_support_code(self):
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160', 'complex192', 'complex512']]
inplace_map_template = """
#if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
%(op)s
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
}
}
#endif
"""
floatadd = "((%(type)s*)mit->dataptr)[0] = inc_or_set * ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
complexadd = """
((%(type)s*)mit->dataptr)[0].real = inc_or_set * ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = inc_or_set * ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join(["""
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"""NULL};
""")
type_number_array = ("static int type_numbers[] = {" +
''.join(["""
#if defined(%(typen)s)
%(typen)s,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};")
return ("""
#include <Python.h>
#include "numpy/arrayobject.h"
extern "C" //{
PyObject *
run_cthunk(PyObject *self, PyObject *args)
{
PyObject *py_cthunk = NULL;
if(!PyArg_ParseTuple(args,"O",&py_cthunk))
return NULL;
if (!PyCObject_Check(py_cthunk)) {
PyErr_SetString(PyExc_ValueError,
"Argument to run_cthunk must be a PyCObject.");
return NULL;
}
void * ptr_addr = PyCObject_AsVoidPtr(py_cthunk);
int (*fn)(void*) = (int (*)(void*))(ptr_addr);
void* it = PyCObject_GetDesc(py_cthunk);
int failure = fn(it);
return Py_BuildValue("i", failure);
}
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int f);
""" + fns + fn_array + type_number_array +
"""
int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArray_Descr *descr;
if (mit->ait == NULL) {
return -1;
}
descr = PyArray_DESCR(mit->ait->ao);
Py_INCREF(descr);
arr = (PyArrayObject *)PyArray_FromAny(op, descr,
0, 0, NPY_ARRAY_FORCECAST, NULL);
if (arr == NULL) {
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
}
}
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
return -1;
}
(*add_inplace)(mit, it, inc_or_set);
Py_DECREF(arr);
Py_DECREF(it);
return 0;
}
PyObject *
inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
int inc_or_set;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i = 0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOOi", &arg_a, &index,
&inc, &inc_or_set)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
return NULL;
}
a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL;
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL;
}
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
Py_INCREF(mit); // Should I INCREF here?
if (mit == NULL) {
goto fail;
}
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(mit);
return NULL;
}
#endif
PyMethodDef CutilsExtMethods[] = {
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."},
#if NPY_API_VERSION >= 0x00000008
{"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */
};""")
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
x, y, idx = input_names x, y, idx = input_names
out = output_names[0] out = output_names[0]
fail = sub['fail'] fail = sub['fail']
inc_or_set = 1 - self.set_instead_of_inc
if self.inplace: # convert bool to int
inplace = 1
else:
inplace = 0
copy_of_x = self.copy_of_x(x)
return """ return """
PyObject *arglist = PyTuple_Pack(3,%(x)s, %(idx)s, %(y)s); if (%(inplace)s)
//PyObject *result /*Will be PyNone*/ {
//result = PyEval_CallObject(inplace_increment, arglist); if (%(x)s != %(out)s)
//Py_DECREF(arglist) {
//Py_DECREF(result) Py_XDECREF(%(out)s);
Py_INCREF(%(x)s);
%(out)s = %(x)s;
}
}
else
{
Py_XDECREF(%(out)s);
%(out)s = %(copy_of_x)s;
}
PyObject *arglist = Py_BuildValue("OOOi",%(out)s, %(idx)s, %(y)s, %(inc_or_set)d);
inplace_increment(NULL, arglist);
""" % locals() """ % locals()
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论