提交 d0bb8089 作者: Arnaud Bergeron

Propagate the allow_gc flag in scan optimizations.

I hope I got all the spots. This is insane.
上级 b2bc4f62
...@@ -967,6 +967,8 @@ def scan(fn, ...@@ -967,6 +967,8 @@ def scan(fn,
## ##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)] tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
if allow_gc is None:
allow_gc = config.scan.allow_gc
info = OrderedDict() info = OrderedDict()
info['tap_array'] = tap_array info['tap_array'] = tap_array
...@@ -985,8 +987,9 @@ def scan(fn, ...@@ -985,8 +987,9 @@ def scan(fn,
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
info['allow_gc'] = allow_gc
local_op = scan_op.Scan(inner_inputs, new_outs, info, allow_gc=allow_gc) local_op = scan_op.Scan(inner_inputs, new_outs, info)
## ##
### Step 8. Compute the outputs using the scan op ### Step 8. Compute the outputs using the scan op
......
...@@ -49,7 +49,6 @@ class Scan(PureOp): ...@@ -49,7 +49,6 @@ class Scan(PureOp):
inputs, inputs,
outputs, outputs,
info, info,
allow_gc=None
): ):
""" """
:param inputs: inputs of the inner function of scan :param inputs: inputs of the inner function of scan
...@@ -58,13 +57,9 @@ class Scan(PureOp): ...@@ -58,13 +57,9 @@ class Scan(PureOp):
the scan op (like number of different types of the scan op (like number of different types of
arguments, name, mode, if it should run on GPU or arguments, name, mode, if it should run on GPU or
not, etc.) not, etc.)
:param allow_gc: Use the gc in the inner function or not
(independant of the outer function)
""" """
if 'gpua' not in info: if 'gpua' not in info:
info['gpua'] = False info['gpua'] = False
if allow_gc is None:
allow_gc = config.scan.allow_gc
# adding properties into self # adding properties into self
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
...@@ -73,7 +68,6 @@ class Scan(PureOp): ...@@ -73,7 +68,6 @@ class Scan(PureOp):
# since info contains all tunable parameters of the op, so for two # since info contains all tunable parameters of the op, so for two
# scan to be equal this tunable parameters should be the same # scan to be equal this tunable parameters should be the same
self.info = info self.info = info
# build a list of output types for any Apply node using this op. # build a list of output types for any Apply node using this op.
self.output_types = [] self.output_types = []
idx = 0 idx = 0
...@@ -111,7 +105,7 @@ class Scan(PureOp): ...@@ -111,7 +105,7 @@ class Scan(PureOp):
isinstance(mode_instance, compile.profilemode.ProfileMode)): isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode( mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer, optimizer=mode_instance.provided_optimizer,
linker=mode_instance.linker.clone(allow_gc=allow_gc)) linker=mode_instance.linker.clone(allow_gc=self.allow_gc))
compile.profilemode.prof_mode_instance_to_print.append( compile.profilemode.prof_mode_instance_to_print.append(
mode_instance) mode_instance)
self.mode_instance = mode_instance self.mode_instance = mode_instance
...@@ -122,7 +116,7 @@ class Scan(PureOp): ...@@ -122,7 +116,7 @@ class Scan(PureOp):
else: else:
self.mode_instance = type(mode_instance)( self.mode_instance = type(mode_instance)(
optimizer=mode_instance.provided_optimizer, optimizer=mode_instance.provided_optimizer,
linker=mode_instance.linker.clone(allow_gc=allow_gc)) linker=mode_instance.linker.clone(allow_gc=self.allow_gc))
if not hasattr(self, 'name') or self.name is None: if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn' self.name = 'scan_fn'
...@@ -456,15 +450,11 @@ class Scan(PureOp): ...@@ -456,15 +450,11 @@ class Scan(PureOp):
if self_in.type != other_in.type: if self_in.type != other_in.type:
return False return False
if not scan_utils.equal_computations(self.outputs, return scan_utils.equal_computations(self.outputs,
other.outputs, other.outputs,
self.inputs, self.inputs,
other.inputs): other.inputs)
return False
# If they do, then they need to match in other small details
# like name, mode, etc.
return True
def __str__(self): def __str__(self):
if self.gpu: if self.gpu:
......
...@@ -1500,6 +1500,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1500,6 +1500,7 @@ class ScanMerge(gof.Optimizer):
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = nodes[0].op.profile info['profile'] = nodes[0].op.profile
info['allow_gc'] = nodes[0].op.allow_gc
# We keep the inner_ins and inner_outs of each original node separated. # We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone, # To be able to recombine them in the right order after the clone,
......
...@@ -654,9 +654,11 @@ def compress_outs(op, not_required, inputs): ...@@ -654,9 +654,11 @@ def compress_outs(op, not_required, inputs):
info['truncate_gradient'] = op.info['truncate_gradient'] info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name'] info['name'] = op.info['name']
info['gpu'] = op.info['gpu'] info['gpu'] = op.info['gpu']
info['gpua'] = op.info['gpua']
info['mode'] = op.info['mode'] info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while'] info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile'] info['profile'] = op.info['profile']
info['allow_gc'] = op.info['allow_gc']
op_inputs = op.inputs[:op.n_seqs] op_inputs = op.inputs[:op.n_seqs]
op_outputs = [] op_outputs = []
...@@ -919,7 +921,7 @@ class scan_args(object): ...@@ -919,7 +921,7 @@ class scan_args(object):
self.other_info = OrderedDict() self.other_info = OrderedDict()
for k in ('truncate_gradient', 'name', 'mode', 'destroy_map', for k in ('truncate_gradient', 'name', 'mode', 'destroy_map',
'gpu', 'as_while', 'profile'): 'gpu', 'gpua', 'as_while', 'profile', 'allow_gc'):
if k in info: if k in info:
self.other_info[k] = info[k] self.other_info[k] = info[k]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论