提交 db6e7b56 authored 作者: --global's avatar --global

Modify ScanSaveMem to keep a buffer big enough for the memory reuse feature

上级 4ae75bc3
...@@ -675,7 +675,7 @@ class Scan(PureOp): ...@@ -675,7 +675,7 @@ class Scan(PureOp):
self.n_sit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
wrapped_inputs = [Param(x, borrow=False) for x in self.inputs] wrapped_inputs = [Param(x, borrow=False) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=(x not in self.inputs)) for x in wrapped_outputs = [Out(x, borrow=True) for x in
self.outputs[:slices]] self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:] wrapped_outputs += self.outputs[slices:]
profile = None profile = None
......
...@@ -1228,8 +1228,32 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1228,8 +1228,32 @@ class ScanSaveMem(gof.Optimizer):
if start == 0 or store_steps[i] == 0: if start == 0 or store_steps[i] == 0:
store_steps[i] = 0 store_steps[i] = 0
else: else:
pval = select_max(nw_steps - start + init_l[i], # The "+ 1" is because of the memory pre-allocation
init_l[i]) # mechanism used in the Scan op to reduce overhead.
# To prevent aliasing between the inputs and outputs
# of recurrent states, it requires that the buffer be
# large enough that the new state and the oldest
# tap needed don't occupy the same place in the
# circular buffer. For now, this only needs to be done
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# the inner function has more than one output
# (otherwise, there is no risk of aliasing because
# once the output is computed, the oldest tap can
# safely be overwritten).
first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (node.op.n_mit_mot +
node.op.n_mit_sot +
node.op.n_sit_sot - 1)
if (i >= first_mitsot_idx and i <= last_sitsot_idx and
len(node.op.outputs) > 1):
pval = select_max(nw_steps - start + init_l[i],
init_l[i] + 1)
else:
pval = select_max(nw_steps - start + init_l[i],
init_l[i])
if store_steps[i] != -1: if store_steps[i] != -1:
pval = select_max(pval, store_steps[i]) pval = select_max(pval, store_steps[i])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论