提交 9a289b81 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2173 from RoyXue/theano_function_free

Issue 2169 a_theano_function.free() to free the temp memory
...@@ -19,6 +19,9 @@ import theano.compile.mode ...@@ -19,6 +19,9 @@ import theano.compile.mode
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput) In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
from theano.gof.op import ops_with_inner_function
import logging import logging
_logger = logging.getLogger('theano.compile.function_module') _logger = logging.getLogger('theano.compile.function_module')
...@@ -295,6 +298,7 @@ class Function(object): ...@@ -295,6 +298,7 @@ class Function(object):
self.profile = None # reassigned in FunctionMaker.create self.profile = None # reassigned in FunctionMaker.create
self.trust_input = False # If True, we don't check the input parameter self.trust_input = False # If True, we don't check the input parameter
self.name = None self.name = None
self.node_op_list = []
# We will be popping stuff off this `containers` object. It is a copy. # We will be popping stuff off this `containers` object. It is a copy.
containers = list(self.input_storage) containers = list(self.input_storage)
...@@ -451,6 +455,10 @@ class Function(object): ...@@ -451,6 +455,10 @@ class Function(object):
if input.update is not None: if input.update is not None:
self.n_returned_outputs -= 1 self.n_returned_outputs -= 1
for node in self.maker.fgraph.apply_nodes:
if node.op in ops_with_inner_function.keys():
self.node_op_list.append(node.op)
def __contains__(self, item): def __contains__(self, item):
return self.value.__contains__(item) return self.value.__contains__(item)
...@@ -678,8 +686,22 @@ class Function(object): ...@@ -678,8 +686,22 @@ class Function(object):
None, # this property itself is not settable None, # this property itself is not settable
doc="""dictionary-like access to the containers associated with Variables""") doc="""dictionary-like access to the containers associated with Variables""")
# pickling/deepcopy support for Function
def free(self):
"""
When allow_gc = False, clear the Variables in storage_map
"""
# 1.no allow_gc return False 2.has allow_gc, if allow_gc is False, return True
if not getattr(self.fn, 'allow_gc', True):
for key in self.fn.storage_map.keys():
if not isinstance(key, theano.gof.Constant):
self.fn.storage_map[key][0] = None
for node in self.node_op_list:
ops_with_inner_function[node.op].free()
# pickling/deepcopy support for Function
def _pickle_Function(f): def _pickle_Function(f):
#copy of the input storage list #copy of the input storage list
......
...@@ -399,6 +399,26 @@ class T_function(unittest.TestCase): ...@@ -399,6 +399,26 @@ class T_function(unittest.TestCase):
y = x * 2 y = x * 2
self.assertRaises(RuntimeError, function, [x], y, givens={x: x + 1}) self.assertRaises(RuntimeError, function, [x], y, givens={x: x + 1})
def test_free(self):
"""
Make test on free() function
"""
x = T.vector('x')
func = function([x], x+1)
func.fn.allow_gc = False
func([1])
check_list = []
for key, val in func.fn.storage_map.iteritems():
if not isinstance(key, theano.gof.Constant):
check_list.append(val)
assert any([val[0] for val in check_list])
func.free()
for key, val in func.fn.storage_map.iteritems():
if not isinstance(key, theano.gof.Constant):
assert (val[0] == None)
class T_picklefunction(unittest.TestCase): class T_picklefunction(unittest.TestCase):
......
...@@ -930,6 +930,34 @@ static PyMethodDef CLazyLinker_methods[] = { ...@@ -930,6 +930,34 @@ static PyMethodDef CLazyLinker_methods[] = {
}; };
#endif #endif
static PyObject *
CLazyLinker_get_allow_gc(CLazyLinker *self, void *closure)
{
return PyBool_FromLong(self->allow_gc);
}
static int
CLazyLinker_set_allow_gc(CLazyLinker *self, PyObject *value, void *closure)
{
if(!PyBool_Check(value))
return -1;
if (value == Py_True)
self->allow_gc = true;
else
self->allow_gc = false;
return 0;
}
static PyGetSetDef CLazyLinker_getset[] = {
{"allow_gc",
(getter)CLazyLinker_get_allow_gc,
(setter)CLazyLinker_set_allow_gc,
"do this function support allow_gc",
NULL},
{NULL, NULL, NULL, NULL} /* Sentinel */
};
static PyMemberDef CLazyLinker_members[] = { static PyMemberDef CLazyLinker_members[] = {
{(char*)"nodes", T_OBJECT_EX, offsetof(CLazyLinker, nodes), 0, {(char*)"nodes", T_OBJECT_EX, offsetof(CLazyLinker, nodes), 0,
(char*)"list of nodes"}, (char*)"list of nodes"},
...@@ -983,7 +1011,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { ...@@ -983,7 +1011,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
0, /* tp_iternext */ 0, /* tp_iternext */
0,//CLazyLinker_methods, /* tp_methods */ 0,//CLazyLinker_methods, /* tp_methods */
CLazyLinker_members, /* tp_members */ CLazyLinker_members, /* tp_members */
0, /* tp_getset */ CLazyLinker_getset, /* tp_getset */
0, /* tp_base */ 0, /* tp_base */
0, /* tp_dict */ 0, /* tp_dict */
0, /* tp_descr_get */ 0, /* tp_descr_get */
......
...@@ -184,6 +184,25 @@ def test_speed_lazy(): ...@@ -184,6 +184,25 @@ def test_speed_lazy():
time_linker('vmLinker_C', lambda : vm.VM_Linker(allow_gc=False, time_linker('vmLinker_C', lambda : vm.VM_Linker(allow_gc=False,
use_cloop=True)) use_cloop=True))
def test_allow_gc_cvm():
v = theano.tensor.vector()
f = theano.function([v], v + 1)
f([1])
n = list(f.maker.fgraph.apply_nodes)[0].outputs[0]
assert f.fn.storage_map[n][0] is None
assert f.fn.allow_gc is True
f.fn.allow_gc = False
assert f.fn.allow_gc is False
f([1])
assert f.fn.storage_map[n][0] is not None
f.fn.allow_gc = True
assert f.fn.allow_gc is True
f([1])
assert f.fn.storage_map[n][0] is None
run_memory_usage_tests = False run_memory_usage_tests = False
if run_memory_usage_tests: if run_memory_usage_tests:
# these are not normal unit tests, do not run them as part of standard # these are not normal unit tests, do not run them as part of standard
......
...@@ -924,6 +924,8 @@ class VM_Linker(link.LocalLinker): ...@@ -924,6 +924,8 @@ class VM_Linker(link.LocalLinker):
self.updated_vars self.updated_vars
) )
vm.storage_map = storage_map
return (vm, return (vm,
[link.Container(input, storage) [link.Container(input, storage)
for input, storage in zip(fgraph.inputs, input_storage)], for input, storage in zip(fgraph.inputs, input_storage)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论