提交 a4178507 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2124 from abergeron/fix_scan_leak

Fix memory leak in scan related to the input and output storage of the inner function.
diff --git a/theano/scan_module/scan_perform.c b/theano/scan_module/scan_perform.c
index aaebb43..2d06b29 100644
--- a/theano/scan_module/scan_perform.c
+++ b/theano/scan_module/scan_perform.c
@@ -5595,7 +5595,7 @@ static int __pyx_pf_5numpy_7ndarray___getbuffer__(PyArrayObject *__pyx_v_self, P
* cdef list stack
* cdef int offset
*/
- __pyx_t_4 = ((PyObject *)__pyx_v_self->descr);
+ __pyx_t_4 = ((PyObject *)PyArray_DESCR(__pyx_v_self));
__Pyx_INCREF(__pyx_t_4);
__pyx_v_descr = ((PyArray_Descr *)__pyx_t_4);
__pyx_t_4 = 0;
@@ -7147,7 +7147,7 @@ static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_a
* arr.base = baseptr
*
*/
- Py_XDECREF(__pyx_v_arr->base);
+ Py_XDECREF(PyArray_BASE(__pyx_v_arr));
/* "/home/anakha/.local/lib/python2.7/site-packages/Cython/Includes/numpy/__init__.pxd":974
* baseptr = <PyObject*>base
@@ -7156,7 +7156,11 @@ static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_a
*
* cdef inline object get_array_base(ndarray arr):
*/
- __pyx_v_arr->base = __pyx_v_baseptr;
+#if NPY_API < 0x00000007
+ PyArray_BASE(__pyx_v_arr) = __pyx_v_baseptr;
+#else
+ PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_baseptr);
+#endif
/* "/home/anakha/.local/lib/python2.7/site-packages/Cython/Includes/numpy/__init__.pxd":966
*
@@ -7191,7 +7195,7 @@ static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__py
* return None
* else:
*/
- __pyx_t_1 = ((__pyx_v_arr->base == NULL) != 0);
+ __pyx_t_1 = ((PyArray_BASE(__pyx_v_arr) == NULL) != 0);
if (__pyx_t_1) {
/* "/home/anakha/.local/lib/python2.7/site-packages/Cython/Includes/numpy/__init__.pxd":978
@@ -7214,8 +7218,8 @@ static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__py
* return <object>arr.base # <<<<<<<<<<<<<<
*/
__Pyx_XDECREF(__pyx_r);
- __Pyx_INCREF(((PyObject *)__pyx_v_arr->base));
- __pyx_r = ((PyObject *)__pyx_v_arr->base);
+ __Pyx_INCREF(((PyObject *)PyArray_BASE(__pyx_v_arr)));
+ __pyx_r = ((PyObject *)PyArray_BASE(__pyx_v_arr));
goto __pyx_L0;
}
...@@ -1104,6 +1104,13 @@ class Scan(PureOp): ...@@ -1104,6 +1104,13 @@ class Scan(PureOp):
# little trick that I used # little trick that I used
outs[idx][0] = outs[idx][0][:-(n_steps - i)] outs[idx][0] = outs[idx][0][:-(n_steps - i)]
# We never reuse the input or output storage of the
# inner function so we clear it.
for i_s in input_storage:
i_s.storage[0] = None
for o_s in output_storage:
o_s.storage[0] = None
t_call = time.time() - t0_call t_call = time.time() - t0_call
# NOTE: make this match what's in function_module.Function # NOTE: make this match what's in function_module.Function
# and this little string helps us to find this spot: # and this little string helps us to find this spot:
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ import copy ...@@ -62,7 +62,7 @@ import copy
def get_version(): def get_version():
return 0.281 return 0.283
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -456,6 +456,13 @@ def perform( ...@@ -456,6 +456,13 @@ def perform(
sh0 = outs[idx][0].shape[0] sh0 = outs[idx][0].shape[0]
outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)] outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)]
# We never reuse the input or output storage of the
# inner function so we clear it.
for i_s in input_storage:
i_s.storage[0] = None
for o_s in output_storage:
o_s.storage[0] = None
t_call = time.time() - t0_call t_call = time.time() - t0_call
if hasattr(fnct.maker, 'profile'): if hasattr(fnct.maker, 'profile'):
......
...@@ -16,7 +16,7 @@ _logger = logging.getLogger('theano.scan_module.scan_perform') ...@@ -16,7 +16,7 @@ _logger = logging.getLogger('theano.scan_module.scan_perform')
_logger.setLevel(logging.WARN) _logger.setLevel(logging.WARN)
version = 0.281 # must match constant returned in function get_version() version = 0.283 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
...@@ -18,6 +18,7 @@ from theano.compile.pfunc import rebuild_collect_shared ...@@ -18,6 +18,7 @@ from theano.compile.pfunc import rebuild_collect_shared
from theano.gof.python25 import any from theano.gof.python25 import any
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
import theano.scalar.sharedvar import theano.scalar.sharedvar
from theano.scan_module.scan_op import Scan
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.compat import PY3 from theano.compat import PY3
...@@ -261,6 +262,45 @@ class T_Scan(unittest.TestCase): ...@@ -261,6 +262,45 @@ class T_Scan(unittest.TestCase):
theano_values = my_f(state, steps) theano_values = my_f(state, steps)
utt.assert_allclose(numpy_values, theano_values) utt.assert_allclose(numpy_values, theano_values)
# Test that the inner input_storage and output_storage are
# properly cleared
def test_inner_storage_leak(self):
def f_pow2(x_tm1):
return 2 * x_tm1
state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps')
output, updates = theano.scan(f_pow2,
[],
state,
[],
n_steps=n_steps)
f = theano.function([state, n_steps],
output,
updates=updates,
allow_input_downcast=True)
scan_node = [node for node in f.maker.fgraph.toposort()
if isinstance(node.op, Scan)]
assert len(scan_node) == 1
scan_node = scan_node[0]
# Make sure they start out as None
assert all(i.value is None for i in scan_node.op.fn.input_storage)
assert all(o.value is None for o in scan_node.op.fn.output_storage)
rng = numpy.random.RandomState(utt.fetch_seed())
state = rng.uniform()
steps = 5
f(state, steps)
# And that they stay that way
assert all(i.value is None for i in scan_node.op.fn.input_storage)
assert all(o.value is None for o in scan_node.op.fn.output_storage)
# generator network, only one output , type scalar ; no sequence or # generator network, only one output , type scalar ; no sequence or
# non sequence arguments # non sequence arguments
def test_generator_one_output_scalar(self): def test_generator_one_output_scalar(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论