提交 b01cf487 authored 作者: Frederic's avatar Frederic

Use new Numpy c api

上级 4e27ad46
...@@ -67,9 +67,9 @@ you should check the strides and alignment. ...@@ -67,9 +67,9 @@ you should check the strides and alignment.
if (!%(y)s) if (!%(y)s)
%(fail)s; %(fail)s;
{//New scope needed to make compilation work {//New scope needed to make compilation work
dtype_%(y)s * y = (dtype_%(y)s*)%(y)s->data; dtype_%(y)s * y = (dtype_%(y)s*)PyArray_DATA(%(y)s);
dtype_%(x)s * x = (dtype_%(x)s*)%(x)s->data; dtype_%(x)s * x = (dtype_%(x)s*)PyArray_DATA(%(x)s);
for (int i = 2; i < %(x)s->dimensions[0]; ++i) for (int i = 2; i < PyArray_DIMS(%(x)s)[0]; ++i)
y[i] = y[i-1]*y[i-2] + x[i]; y[i] = y[i-1]*y[i-2] + x[i];
} }
""" % locals() """ % locals()
......
...@@ -3905,7 +3905,7 @@ class Reshape(Op): ...@@ -3905,7 +3905,7 @@ class Reshape(Op):
} }
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, %(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape,
PyArray_CORDER); NPY_CORDER);
if (!%(z)s) if (!%(z)s)
{ {
//The error message should have been set by PyArray_Newshape //The error message should have been set by PyArray_Newshape
......
...@@ -118,14 +118,14 @@ class SoftmaxWithBias(gof.Op): ...@@ -118,14 +118,14 @@ class SoftmaxWithBias(gof.Op):
PyErr_SetString(PyExc_ValueError, "b not 1d tensor"); PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
%(fail)s; %(fail)s;
} }
if ((PyArray_DESCR(%(x)s)->type_num != NPY_DOUBLE) && if ((PyArray_TYPE(%(x)s) != NPY_DOUBLE) &&
(PyArray_DESCR(%(x)s)->type_num != NPY_FLOAT)) (PyArray_TYPE(%(x)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, "not a float"); PyErr_SetString(PyExc_TypeError, "not a float");
%(fail)s; %(fail)s;
} }
if ((PyArray_DESCR(%(b)s)->type_num != NPY_DOUBLE) && if ((PyArray_TYPE(%(b)s) != NPY_DOUBLE) &&
(PyArray_DESCR(%(b)s)->type_num != NPY_FLOAT)) (PyArray_TYPE(%(b)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, "b not float"); PyErr_SetString(PyExc_TypeError, "b not float");
%(fail)s; %(fail)s;
...@@ -264,15 +264,15 @@ class SoftmaxGrad(gof.Op): ...@@ -264,15 +264,15 @@ class SoftmaxGrad(gof.Op):
dy, sm = inp dy, sm = inp
dx, = out dx, = out
return ''' return '''
if ((PyArray_DESCR(%(dy)s)->type_num != NPY_DOUBLE) && if ((PyArray_TYPE(%(dy)s) != NPY_DOUBLE) &&
(PyArray_DESCR(%(dy)s)->type_num != NPY_FLOAT)) (PyArray_TYPE(%(dy)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"types should be float or float64"); "types should be float or float64");
%(fail)s; %(fail)s;
} }
if ((PyArray_DESCR(%(sm)s)->type_num != NPY_DOUBLE) && if ((PyArray_TYPE(%(sm)s) != NPY_DOUBLE) &&
(PyArray_DESCR(%(sm)s)->type_num != NPY_FLOAT)) (PyArray_TYPE(%(sm)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"types should be float or float64"); "types should be float or float64");
...@@ -395,23 +395,23 @@ class Softmax(gof.Op): ...@@ -395,23 +395,23 @@ class Softmax(gof.Op):
#TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """ init_decl = """
npy_intp* Nx = %(x)s->dimensions; npy_intp* Nx = PyArray_DIMS(%(x)s);
if (%(x)s->nd != 2) if (PyArray_NDIM(%(x)s) != 2)
{ {
PyErr_SetString(PyExc_ValueError, "not a 2d tensor"); PyErr_SetString(PyExc_ValueError, "not a 2d tensor");
%(fail)s; %(fail)s;
} }
if ((%(x)s->descr->type_num != PyArray_DOUBLE) && if ((PyArray_TYPE(%(x)s) != NPY_DOUBLE) &&
(%(x)s->descr->type_num != PyArray_FLOAT)) (PyArray_TYPE(%(x)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, "not a float"); PyErr_SetString(PyExc_TypeError, "not a float");
%(fail)s; %(fail)s;
} }
if ((NULL == %(sm)s) if ((NULL == %(sm)s)
|| (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) || (PyArray_DIMS(%(sm)s)[0] != PyArray_DIMS(%(x)s)[0])
|| (%(sm)s->dimensions[1] != %(x)s->dimensions[1])) || (PyArray_DIMS(%(sm)s)[1] != PyArray_DIMS(%(x)s)[1]))
{ {
if (NULL != %(sm)s) Py_XDECREF(%(sm)s); if (NULL != %(sm)s) Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
...@@ -431,13 +431,13 @@ class Softmax(gof.Op): ...@@ -431,13 +431,13 @@ class Softmax(gof.Op):
double sum = 0.0; double sum = 0.0;
bool discount_max = false; bool discount_max = false;
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(%(x)s->data + %(x)s->strides[0] * i); const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i);
dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(%(sm)s->data + %(sm)s->strides[0] * i); dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i);
""" """
inside_row_loop = """ inside_row_loop = """
npy_intp Sx = %(x)s->strides[1]/sizeof(dtype_%(x)s); npy_intp Sx = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
npy_intp Ssm = %(sm)s->strides[1]/sizeof(dtype_%(sm)s); npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
size_t row_max_j=0; size_t row_max_j=0;
dtype_%(sm)s row_max = x_i[0]; dtype_%(sm)s row_max = x_i[0];
...@@ -1018,15 +1018,15 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -1018,15 +1018,15 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
y_idx_type = node.inputs[2].type.dtype_specs()[1] y_idx_type = node.inputs[2].type.dtype_specs()[1]
return """ return """
if ((PyArray_DESCR(%(dnll)s)->type_num != NPY_DOUBLE) && if ((PyArray_TYPE(%(dnll)s) != NPY_DOUBLE) &&
(PyArray_DESCR(%(dnll)s)->type_num != NPY_FLOAT)) (PyArray_TYPE(%(dnll)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"dnll type should be float32 or float64"); "dnll type should be float32 or float64");
%(fail)s; %(fail)s;
} }
if ((PyArray_DESCR(%(sm)s)->type_num != NPY_DOUBLE) && if ((PyArray_TYPE(%(sm)s) != NPY_DOUBLE) &&
(PyArray_DESCR(%(sm)s)->type_num != NPY_FLOAT)) (PyArray_TYPE(%(sm)s) != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"sm type should be float32 or float64"); "sm type should be float32 or float64");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论