提交 df48ecec authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added a missing corner case to scan save mem optimization ( when there is

nothing else to be done except removing orphane outputs )
上级 233e9061
...@@ -350,6 +350,9 @@ class ScanSaveMem(Optimizer): ...@@ -350,6 +350,9 @@ class ScanSaveMem(Optimizer):
store_steps[i] = pval store_steps[i] = pval
flag_store = True flag_store = True
orphane_outs = [ i for i,x in enumerate(store_steps)
if (type(x) is int) and (x<0) ]
flag_store = flag_store or (len(orphane_outs) > 0 )
# 3. is there anything to change ? # 3. is there anything to change ?
if (flag_store or global_nsteps is not None): if (flag_store or global_nsteps is not None):
# 3.1 initialize inputs for the new scan # 3.1 initialize inputs for the new scan
...@@ -358,8 +361,6 @@ class ScanSaveMem(Optimizer): ...@@ -358,8 +361,6 @@ class ScanSaveMem(Optimizer):
nw_inputs[0] = nw_steps nw_inputs[0] = nw_steps
# 3.2 check orphane outputs to see if we can eliminate any # 3.2 check orphane outputs to see if we can eliminate any
orphane_outs = [ i for i,x in enumerate(store_steps)
if (type(x) is int) and (x < 0) ]
required,not_required = \ required,not_required = \
scan_utils.scan_can_remove_outs(node.op scan_utils.scan_can_remove_outs(node.op
, orphane_outs) , orphane_outs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论