提交 17bc0839 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed inplace optimization

In order for scan to run correctly inplace, it needs that none of the initial states are the same memory buffer.
上级 667104c4
...@@ -292,10 +292,24 @@ def scan_make_inplace(node): ...@@ -292,10 +292,24 @@ def scan_make_inplace(node):
(not op.info['inplace']) ): (not op.info['inplace']) ):
info = op.info.copy() info = op.info.copy()
info['inplace'] = True info['inplace'] = True
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1+op.n_seqs]
ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
new_op = scan_op.Scan( op.inputs new_op = scan_op.Scan( op.inputs
, op.outputs , op.outputs
, info) , info)
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*inputs).outputs
return False return False
optdb.register( 'scanOp_make_inplace' optdb.register( 'scanOp_make_inplace'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论