提交 17d1e8c4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the code written while way too tired.

上级 613b4e93
...@@ -183,7 +183,7 @@ raise_with_op.print_thunk_trace = False ...@@ -183,7 +183,7 @@ raise_with_op.print_thunk_trace = False
class Linker(object): class Linker(object):
"""WRITEME""" """WRITEME"""
def clone(allow_gc=undef): def clone(self, allow_gc=undef):
new = copy(self) new = copy(self)
if allow_gc is not undef: if allow_gc is not undef:
new.allow_gc = allow_gc new.allow_gc = allow_gc
......
...@@ -38,9 +38,10 @@ _logger = logging.getLogger('theano.scan_module.scan_op') ...@@ -38,9 +38,10 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
from theano.configparser import AddConfigVar, BoolParam from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('scan.allow_gc', AddConfigVar('scan.allow_gc',
"Allow/disallow gc inside of Scan", "Allow/disallow gc inside of Scan (default: config.allow_gc)",
BoolParam(False)) BoolParam(lambda: config.allow_gc))
class Scan(PureOp): class Scan(PureOp):
...@@ -110,7 +111,7 @@ class Scan(PureOp): ...@@ -110,7 +111,7 @@ class Scan(PureOp):
isinstance(mode_instance, compile.profilemode.ProfileMode)): isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode( mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer, optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker) linker=mode_instance.linker.clone(allow_gc=allow_gc))
compile.profilemode.prof_mode_instance_to_print.append( compile.profilemode.prof_mode_instance_to_print.append(
mode_instance) mode_instance)
self.mode_instance = mode_instance self.mode_instance = mode_instance
...@@ -119,10 +120,9 @@ class Scan(PureOp): ...@@ -119,10 +120,9 @@ class Scan(PureOp):
else: else:
self.mode_instance.message = "Scan sub profile" self.mode_instance.message = "Scan sub profile"
else: else:
mode_instance = mode_instance.__type__( self.mode_instance = type(mode_instance)(
optimizer=mode_instance.provided_optimizer, optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker.clone(allow_gc=allow_gc)) linker=mode_instance.linker.clone(allow_gc=allow_gc))
self.mode_instance = mode_instance
if not hasattr(self, 'name') or self.name is None: if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn' self.name = 'scan_fn'
...@@ -632,10 +632,17 @@ class Scan(PureOp): ...@@ -632,10 +632,17 @@ class Scan(PureOp):
p = self.execute p = self.execute
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): # Big ugly hack since we can't get the real value of allow_gc
# for the englobing function.
allow_gc = config.allow_gc
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
allow_gc=allow_gc):
r = p(n, [x[0] for x in i], o) r = p(n, [x[0] for x in i], o)
for o in node.outputs: for o in node.outputs:
compute_map[o][0] = True compute_map[o][0] = True
if allow_gc:
self.fn.free()
return r return r
rval.inputs = node_input_storage rval.inputs = node_input_storage
rval.outputs = node_output_storage rval.outputs = node_output_storage
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论