提交 b76520e7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change storage clearing to always be done.

Also add patch for numpy API fixes to help people playing with scan_perform.pyx
上级 91172b5f
diff --git a/theano/scan_module/scan_perform.c b/theano/scan_module/scan_perform.c
index aaebb43..2d06b29 100644
--- a/theano/scan_module/scan_perform.c
+++ b/theano/scan_module/scan_perform.c
@@ -5595,7 +5595,7 @@ static int __pyx_pf_5numpy_7ndarray___getbuffer__(PyArrayObject *__pyx_v_self, P
* cdef list stack
* cdef int offset
*/
- __pyx_t_4 = ((PyObject *)__pyx_v_self->descr);
+ __pyx_t_4 = ((PyObject *)PyArray_DESCR(__pyx_v_self));
__Pyx_INCREF(__pyx_t_4);
__pyx_v_descr = ((PyArray_Descr *)__pyx_t_4);
__pyx_t_4 = 0;
@@ -7147,7 +7147,7 @@ static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_a
* arr.base = baseptr
*
*/
- Py_XDECREF(__pyx_v_arr->base);
+ Py_XDECREF(PyArray_BASE(__pyx_v_arr));
/* "/home/anakha/.local/lib/python2.7/site-packages/Cython/Includes/numpy/__init__.pxd":974
* baseptr = <PyObject*>base
@@ -7156,7 +7156,11 @@ static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_a
*
* cdef inline object get_array_base(ndarray arr):
*/
- __pyx_v_arr->base = __pyx_v_baseptr;
+#if NPY_API < 0x00000007
+ PyArray_BASE(__pyx_v_arr) = __pyx_v_baseptr;
+#else
+ PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_baseptr);
+#endif
/* "/home/anakha/.local/lib/python2.7/site-packages/Cython/Includes/numpy/__init__.pxd":966
*
@@ -7191,7 +7195,7 @@ static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__py
* return None
* else:
*/
- __pyx_t_1 = ((__pyx_v_arr->base == NULL) != 0);
+ __pyx_t_1 = ((PyArray_BASE(__pyx_v_arr) == NULL) != 0);
if (__pyx_t_1) {
/* "/home/anakha/.local/lib/python2.7/site-packages/Cython/Includes/numpy/__init__.pxd":978
@@ -7214,8 +7218,8 @@ static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__py
* return <object>arr.base # <<<<<<<<<<<<<<
*/
__Pyx_XDECREF(__pyx_r);
- __Pyx_INCREF(((PyObject *)__pyx_v_arr->base));
- __pyx_r = ((PyObject *)__pyx_v_arr->base);
+ __Pyx_INCREF(((PyObject *)PyArray_BASE(__pyx_v_arr)));
+ __pyx_r = ((PyObject *)PyArray_BASE(__pyx_v_arr));
goto __pyx_L0;
}
...@@ -1104,13 +1104,12 @@ class Scan(PureOp): ...@@ -1104,13 +1104,12 @@ class Scan(PureOp):
# little trick that I used # little trick that I used
outs[idx][0] = outs[idx][0][:-(n_steps - i)] outs[idx][0] = outs[idx][0][:-(n_steps - i)]
# Make sure to release storage if allow_gc is True, like # We never reuse the input or output storage of the
# Function.__call__ does. # inner function so we clear it.
if getattr(fn, 'allow_gc', False): for i_s in input_storage:
for i in input_storage: i_s.storage[0] = None
i.storage[0] = None for o_s in output_storage:
for o in output_storage: o_s.storage[0] = None
o.storage[0] = None
t_call = time.time() - t0_call t_call = time.time() - t0_call
# NOTE: make this match what's in function_module.Function # NOTE: make this match what's in function_module.Function
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ import copy ...@@ -62,7 +62,7 @@ import copy
def get_version(): def get_version():
return 0.282 return 0.283
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -456,13 +456,12 @@ def perform( ...@@ -456,13 +456,12 @@ def perform(
sh0 = outs[idx][0].shape[0] sh0 = outs[idx][0].shape[0]
outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)] outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)]
# Make sure to release storage if allow_gc is True, like # We never reuse the input or output storage of the
# Function.__call__ does. # inner function so we clear it.
if getattr(fn, 'allow_gc', False): for i_s in input_storage:
for i in input_storage: i_s.storage[0] = None
i.storage[0] = None for o_s in output_storage:
for o in output_storage: o_s.storage[0] = None
o.storage[0] = None
t_call = time.time() - t0_call t_call = time.time() - t0_call
......
...@@ -16,7 +16,7 @@ _logger = logging.getLogger('theano.scan_module.scan_perform') ...@@ -16,7 +16,7 @@ _logger = logging.getLogger('theano.scan_module.scan_perform')
_logger.setLevel(logging.WARN) _logger.setLevel(logging.WARN)
version = 0.282 # must match constant returned in function get_version() version = 0.283 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论