提交 81d99756 authored 作者: ChienliMa's avatar ChienliMa

Only seperate input storage of mutable SharedVariable

上级 de6cd0af
...@@ -590,7 +590,8 @@ class Function(object): ...@@ -590,7 +590,8 @@ class Function(object):
# so that they have different storage as their value # so that they have different storage as their value
ins = [copy.copy(input) for input in maker.inputs] ins = [copy.copy(input) for input in maker.inputs]
for in_cpy, in_ori in zip(ins, maker.inputs): for in_cpy, in_ori in zip(ins, maker.inputs):
in_cpy.value = copy.deepcopy(in_ori.value) if in_ori.mutable:
in_cpy.value = copy.deepcopy(in_ori.value)
# Delete update output in fgraph and updates In instances if needed # Delete update output in fgraph and updates In instances if needed
if delete_updates: if delete_updates:
...@@ -613,10 +614,10 @@ class Function(object): ...@@ -613,10 +614,10 @@ class Function(object):
for out_ori, out_cpy in zip(maker.outputs, outs): for out_ori, out_cpy in zip(maker.outputs, outs):
out_cpy.borrow = out_ori.borrow out_cpy.borrow = out_ori.borrow
# Delete update if needed
update_i = len(outs) update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs): for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var i.variable = in_var
# Delete update if needed
if not delete_updates and i.update is not None: if not delete_updates and i.update is not None:
i.update = fg_cpy.outputs[update_i] i.update = fg_cpy.outputs[update_i]
update_i += 1 update_i += 1
...@@ -634,7 +635,7 @@ class Function(object): ...@@ -634,7 +635,7 @@ class Function(object):
raise ValueError("SharedVariable: %s not found" % raise ValueError("SharedVariable: %s not found" %
(sv.name)) (sv.name))
# Swap SharedVariable in fgraph and ins # Swap SharedVariable in fgraph and In instances
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)): for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
# Variables in maker.inputs are defined by user, therefore we # Variables in maker.inputs are defined by user, therefore we
# use them to make comparision and do the mapping. # use them to make comparision and do the mapping.
...@@ -703,8 +704,9 @@ class Function(object): ...@@ -703,8 +704,9 @@ class Function(object):
# Share immutable ShareVariable and constant input's storage # Share immutable ShareVariable and constant input's storage
swapped = swap is not None and in_ori.variable in swap_svs_ori swapped = swap is not None and in_ori.variable in swap_svs_ori
is_const = isinstance(in_ori.variable, theano.tensor.Constant) # Using the original storage if SharedVariable will not be updated
if (is_const or not in_ori.mutable) and not swapped: # and is not swapped
if not in_ori.mutable and not swapped:
cpy.data = ori.data cpy.data = ori.data
in_cpy.value = in_ori.value in_cpy.value = in_ori.value
......
...@@ -272,9 +272,7 @@ class T_function(unittest.TestCase): ...@@ -272,9 +272,7 @@ class T_function(unittest.TestCase):
for (input, _1, _2), here, there in zip(ori.indices, for (input, _1, _2), here, there in zip(ori.indices,
ori.input_storage, ori.input_storage,
cpy.input_storage): cpy.input_storage):
if not input.mutable: if input.mutable:
self.assertTrue(here.data is there.data)
else:
self.assertFalse(here.data is there.data) self.assertFalse(here.data is there.data)
def test_swap_SharedVariable(self): def test_swap_SharedVariable(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论