提交 b6c5be0f authored 作者: ChienliMa's avatar ChienliMa

Move '__swapSV' inline.

上级 04f50909
...@@ -608,9 +608,41 @@ class Function(object): ...@@ -608,9 +608,41 @@ class Function(object):
# swap SharedVariable # swap SharedVariable
if swap is not None: if swap is not None:
self.__swapSV(swap, ins, fg_cpy) def checkSV(sv_ori, sv_rpl):
# the name of SV we swapped """
swapped_sv = swap.keys() Assert two SharedVariable follow some restirctions:
1. same type
2. same shape or dim?
"""
assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with origianl one",
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type)
exist_names = [i.variable.name for i in ins]
swap_names = swap.keys()
# Check if given names exist
for name in swap_names:
if name not in exist_names:
warnings.warn("Given name: %s wasn't found" % (name))
# Swap SharedVairable in fgraph and ins
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
assert i.variable.name == in_v.name
if in_v.name in swap_names:
checkSV(in_v, swap[in_v.name])
# In the fgraph we use the cloned SharedVariable
swap_sv = swap[in_v.name].clone()
# Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv
fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# swap variable and value of In instances
i.variable = swap_sv
i.value = swap_sv.container
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
new_storage_map = {} new_storage_map = {}
...@@ -644,7 +676,7 @@ class Function(object): ...@@ -644,7 +676,7 @@ class Function(object):
f_cpy.input_storage): f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant) is_const = isinstance(in_ori.variable, theano.tensor.Constant)
# In instances' name default to vairables' name # In instances' name default to vairables' name
swapped = swap is not None and in_ori.name in swapped_sv swapped = swap is not None and in_ori.name in swap_names
if (is_const or not in_ori.mutable) and not swapped: if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data cpy.data = ori.data
...@@ -653,7 +685,7 @@ class Function(object): ...@@ -653,7 +685,7 @@ class Function(object):
# Reconstruct Function.finder. # Reconstruct Function.finder.
# Function.value and Function.data work # Function.value and Function.data work
for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs): for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs):
swapped = swap is not None and ori.name in swapped_sv swapped = swap is not None and ori.name in swap_names
if not swapped: if not swapped:
f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable) f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable)
else: else:
...@@ -661,55 +693,6 @@ class Function(object): ...@@ -661,55 +693,6 @@ class Function(object):
return f_cpy return f_cpy
def __swapSV(self, swap, ins, fg_cpy):
"""
Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph.
--------------------
Params:
swap -- { dict } The same as docs in copy()
ins -- { list } List of In instances to be modified.
fg_cpy -- { FounctionGraph } Copied FunctionGraph.
--------------------
Returns:
None
"""
def checkSV(sv_ori, sv_rpl):
"""
Assert two SharedVariable follow some restirctions:
1. same type
2. same shape or dim?
"""
assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with origianl one",
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type)
exist_names = [i.variable.name for i in ins]
swap_names = swap.keys()
# Check if given names exist
for name in swap_names:
if name not in exist_names:
warnings.warn("Given name: %s wasn't found" % (name))
# Swap SharedVairable in fgraph and ins
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
assert i.variable.name == in_v.name
if in_v.name in swap_names:
# In the fgraph we use the cloned SharedVariable
swap_sv = swap[in_v.name].clone()
checkSV(in_v, swap_sv)
# Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv
fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# swap variable and value of In instances
i.variable = swap_sv
i.value = swap_sv.container
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
t0 = time.time() t0 = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论