提交 f303478f authored 作者: ChienliMa's avatar ChienliMa

Delete update should be done after swap sv

上级 876507f6
...@@ -611,16 +611,6 @@ class Function(object): ...@@ -611,16 +611,6 @@ class Function(object):
for out_ori, out_cpy in zip(maker.outputs, outs): for out_ori, out_cpy in zip(maker.outputs, outs):
out_cpy.borrow = out_ori.borrow out_cpy.borrow = out_ori.borrow
# Delete update if needed
update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var
if not delete_updates and i.update is not None:
i.update = fg_cpy.outputs[update_i]
update_i += 1
else:
i.update = None
# swap SharedVariable # swap SharedVariable
if swap is not None: if swap is not None:
swap_svs_ori = swap.keys() swap_svs_ori = swap.keys()
...@@ -650,9 +640,18 @@ class Function(object): ...@@ -650,9 +640,18 @@ class Function(object):
swap_sv = swap_sv.clone() swap_sv = swap_sv.clone()
# Swap SharedVariable in fgraph # Swap SharedVariable in fgraph
# fg_cpy.inputs[index] = swap_sv
fg_cpy.replace(in_v, swap_sv, reason="Swap SV") fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# Delete update if needed
update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var
if not delete_updates and i.update is not None:
i.update = fg_cpy.outputs[update_i]
update_i += 1
else:
i.update = None
# Construct new storage_map that map new variable to old storage, # Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one # so that the ensuing function shares storage with the original one
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
......
...@@ -313,7 +313,7 @@ class T_function(unittest.TestCase): ...@@ -313,7 +313,7 @@ class T_function(unittest.TestCase):
elif second_time: elif second_time:
# doule update for sharedvariable # doule update for sharedvariable
assert m.get_value() == 12 assert m.get_value() == 12
assert z.get_value() == 3 assert z.get_value() == 4
assert z_rpl.get_value() == 8 assert z_rpl.get_value() == 8
assert y_rpl.get_value() == 3 assert y_rpl.get_value() == 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论