提交 012ff768 authored 作者: abalkin's avatar abalkin

Merge pull request #4 from lamblin/take-op-c-code-clean

Fix errors in C code
...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en" ...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en"
import sys import sys
import warnings import warnings
from itertools import izip from itertools import izip
from textwrap import dedent
import numpy import numpy
#from copy import copy as python_copy #from copy import copy as python_copy
...@@ -6592,6 +6593,14 @@ class AdvancedSubtensor1(Op): ...@@ -6592,6 +6593,14 @@ class AdvancedSubtensor1(Op):
x, ilist = ishapes x, ilist = ishapes
return [ilist + x[1:]] return [ilist + x[1:]]
def c_support_code(self):
# In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
# which is not defined. It should be NPY_MIN_LONG instead in that case.
return dedent("""\
#ifndef MIN_LONG
#define MIN_LONG NPY_MIN_LONG
#endif""")
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
if self.__class__ is not AdvancedSubtensor1: if self.__class__ is not AdvancedSubtensor1:
raise MethodNotDefined( raise MethodNotDefined(
...@@ -6607,10 +6616,9 @@ class AdvancedSubtensor1(Op): ...@@ -6607,10 +6616,9 @@ class AdvancedSubtensor1(Op):
// if all values fit. // if all values fit.
if (!PyArray_CanCastSafely(PyArray_TYPE(%(i_name)s), NPY_INTP)) if (!PyArray_CanCastSafely(PyArray_TYPE(%(i_name)s), NPY_INTP))
{ {
PyObject* py_min_val, py_max_val;
npy_int64 min_val, max_val; npy_int64 min_val, max_val;
py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, min_val); PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, NULL);
py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS, max_val); PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS, NULL);
min_val = PyLong_AsLongLong(py_min_val); min_val = PyLong_AsLongLong(py_min_val);
max_val = PyLong_AsLongLong(py_max_val); max_val = PyLong_AsLongLong(py_max_val);
Py_CLEAR(py_min_val); Py_CLEAR(py_min_val);
...@@ -6618,7 +6626,7 @@ class AdvancedSubtensor1(Op): ...@@ -6618,7 +6626,7 @@ class AdvancedSubtensor1(Op):
if ((min_val < NPY_MIN_INTP) || (max_val > NPY_MAX_INTP)) if ((min_val < NPY_MIN_INTP) || (max_val > NPY_MAX_INTP))
{ {
PyExc_SetErr(PyExc_IndexError, "Index contains values " PyErr_SetString(PyExc_IndexError, "Index contains values "
"that are bigger than the maximum array " "that are bigger than the maximum array "
"size on this system."); "size on this system.");
%(fail)s; %(fail)s;
...@@ -6664,7 +6672,7 @@ class AdvancedSubtensor1(Op): ...@@ -6664,7 +6672,7 @@ class AdvancedSubtensor1(Op):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 0, 5) return (0, 0, 6)
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论