提交 fc809d59 authored 作者: --global's avatar --global

Add theano flag to control output preallocation in scan

上级 e76fb8d9
......@@ -86,6 +86,11 @@ AddConfigVar('scan.allow_gc',
"Allow/disallow gc inside of Scan (default: False)",
BoolParam(False))
# Flag controlling whether Scan may pre-allocate storage for its outputs.
# It is consulted when building the inner-function mode, when wrapping the
# inner outputs, and when ScanSaveMem sizes the mitsot/sitsot buffers.
# The registered default is True, so the help text must advertise True.
AddConfigVar('scan.allow_output_prealloc',
             "Allow/disallow memory preallocation for outputs inside of scan "
             "(default: True)",
             BoolParam(True))
class Scan(PureOp):
def __init__(self,
......@@ -185,12 +190,14 @@ class Scan(PureOp):
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.linker.clone(allow_gc=self.allow_gc))
# Now that scan has its mode instance, we activate optimization
# Now that scan has its mode instance, if memory pre-allocation is
# activated for the outputs, we activate the optimization
# add_no_output_from_inplace in this mode instance. This will prevent
# Scan from producing outputs by means of inplace operations and
# therefore allow it to pre-allocate memory storage for the outputs,
# avoiding needless copies.
self.mode_instance = self.mode_instance.including(
if theano.config.scan.allow_output_prealloc:
self.mode_instance = self.mode_instance.including(
"add_no_output_from_inplace")
if not hasattr(self, 'name') or self.name is None:
......@@ -717,7 +724,8 @@ class Scan(PureOp):
self.n_sit_sot +
self.n_nit_sot)
wrapped_inputs = [Param(x, borrow=False) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in
borrow_outputs = theano.config.scan.allow_output_prealloc
wrapped_outputs = [Out(x, borrow=borrow_outputs) for x in
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
profile = None
......
......@@ -1236,12 +1236,18 @@ class ScanSaveMem(gof.Optimizer):
# tap needed don't occupy the same place in the
# circular buffer. For now, this only needs to be done
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism).
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
prealloc_outs = theano.config.scan.allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot +
node.op.n_mit_sot +
node.op.n_sit_sot - 1)
if (i >= first_mitsot_idx and i <= last_sitsot_idx):
preallocable_output = (i >= first_mitsot_idx and
i <= last_sitsot_idx)
if (prealloc_outs and preallocable_output):
pval = select_max(nw_steps - start + init_l[i],
init_l[i] + 1)
else:
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论