提交 14b6c4cb authored 作者: nouiz's avatar nouiz

Merge pull request #316 from pascanur/fix_scan_savemem

Fix scan savemem
...@@ -1593,7 +1593,7 @@ class Scan(PureOp): ...@@ -1593,7 +1593,7 @@ class Scan(PureOp):
# Outputs # Outputs
n_mit_mot_outs = int(numpy.sum([len(x) for x in n_mit_mot_outs = int(numpy.sum([len(x) for x in
self.mit_mot_out_slices])) self.mit_mot_out_slices]))
info['n_mit_mot_outs'] = n_mit_mot_outs info['n_mit_mot_outs'] = n_mit_mot_outs*2
b = 0 b = 0
e = n_mit_mot_outs e = n_mit_mot_outs
inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e] inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e]
......
...@@ -640,13 +640,17 @@ class ScanSaveMem(gof.Optimizer): ...@@ -640,13 +640,17 @@ class ScanSaveMem(gof.Optimizer):
isinstance(nw_inputs[offset + idx].owner.op, isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor)): tensor.IncSubtensor)):
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
val = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i])
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i])) tensor.maximum(val - initl, 0))
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand(_nw_input, tmp) nw_input = scan_utils.expand(_nw_input, tmp)
else: else:
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp = tensor.as_tensor_variable(val)
tensor.as_tensor_variable(val)) initl = tensor.as_tensor_variable(init_l[i])
tmp = tensor.maximum(tmp, initl)
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp)
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset + idx][:tmp] nw_input = nw_inputs[offset + idx][:tmp]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论