提交 e02ce0d8 authored 作者: abergeron's avatar abergeron

Merge pull request #1712 from nouiz/nonetype

Add C code to NoneTypeT and fix MaxAndArgmax c_code to work with it.
......@@ -1435,18 +1435,21 @@ class MaxAndArgmax(Op):
x, axis = inp
max, argmax = out
fail = sub["fail"]
assert NoneConst.equals(node.inputs[1]) or node.inputs[1].ndim == 0
ret = """
int axis;
if((PyObject*)%(axis)s == Py_None){
axis = NPY_MAXDIMS;
}else{
if NoneConst.equals(node.inputs[1]):
axis_code = "axis = NPY_MAXDIMS;"
else:
assert node.inputs[1].ndim == 0
axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s
}
}
""" % locals()
ret = """
int axis;
%(axis_code)s
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if(%(max)s == NULL){
PyErr_SetString(PyExc_ValueError,
......
......@@ -2,7 +2,7 @@
# Slice type and Op. None Type and NoneConst.
#
import theano
from theano.gof import Apply, Constant, Op, Type
from theano.gof import Apply, Constant, Generic, Op, Type
from theano.gradient import DisconnectedType
......@@ -55,7 +55,10 @@ class SliceType(Type):
slicetype = SliceType()
class NoneTypeT(Type):
class NoneTypeT(Generic):
"""
Inherit from Generic to have c code working.
"""
def filter(self, x, strict=False, allow_downcast=None):
if x is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论