提交 17d1e8c4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the code written while way too tired.

上级 613b4e93
......@@ -183,7 +183,7 @@ raise_with_op.print_thunk_trace = False
class Linker(object):
"""WRITEME"""
def clone(allow_gc=undef):
def clone(self, allow_gc=undef):
new = copy(self)
if allow_gc is not undef:
new.allow_gc = allow_gc
......
......@@ -38,9 +38,10 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('scan.allow_gc',
"Allow/disallow gc inside of Scan",
BoolParam(False))
"Allow/disallow gc inside of Scan (default: config.allow_gc)",
BoolParam(lambda: config.allow_gc))
class Scan(PureOp):
......@@ -110,7 +111,7 @@ class Scan(PureOp):
isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker)
linker=mode_instance.linker.clone(allow_gc=allow_gc))
compile.profilemode.prof_mode_instance_to_print.append(
mode_instance)
self.mode_instance = mode_instance
......@@ -119,10 +120,9 @@ class Scan(PureOp):
else:
self.mode_instance.message = "Scan sub profile"
else:
mode_instance = mode_instance.__type__(
self.mode_instance = type(mode_instance)(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker.clone(allow_gc=allow_gc))
self.mode_instance = mode_instance
linker=mode_instance.linker.clone(allow_gc=allow_gc))
if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn'
......@@ -632,10 +632,17 @@ class Scan(PureOp):
p = self.execute
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
# Big ugly hack since we can't get the real value of allow_gc
# for the englobing function.
allow_gc = config.allow_gc
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
allow_gc=allow_gc):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
if allow_gc:
self.fn.free()
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论