提交 290e2b40 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix to the scan's save memory optimizer

上级 9517a606
...@@ -640,13 +640,17 @@ class ScanSaveMem(gof.Optimizer): ...@@ -640,13 +640,17 @@ class ScanSaveMem(gof.Optimizer):
isinstance(nw_inputs[offset + idx].owner.op, isinstance(nw_inputs[offset + idx].owner.op,
tensor.IncSubtensor)): tensor.IncSubtensor)):
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
val = tensor.as_tensor_variable(val)
initl = tensor.as_tensor_variable(init_l[i])
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i])) tensor.maximum(val - initl, 0))
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand(_nw_input, tmp) nw_input = scan_utils.expand(_nw_input, tmp)
else: else:
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp = tensor.as_tensor_variable(val)
tensor.as_tensor_variable(val)) initl = tensor.as_tensor_variable(init_l[i])
tmp = tensor.maximum(tmp, initl)
tmp = pre_greedy_local_optimizer(list_opt_slice, tmp)
tmp = pre_constant_merge([tmp])[0] tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset + idx][:tmp] nw_input = nw_inputs[offset + idx][:tmp]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论