提交 09fe88b9 authored 作者: ChienliMa's avatar ChienliMa

pep8 coding style fix and return 'del a' line in test_leak2

上级 0c1cdd89
...@@ -597,14 +597,14 @@ class Function(object): ...@@ -597,14 +597,14 @@ class Function(object):
update_i = len(outs) update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs): for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var i.variable = in_var
if not delete_updates and i.update != None: if not delete_updates and i.update is not None:
i.update = fg_cpy.outputs[update_i] i.update = fg_cpy.outputs[update_i]
update_i += 1 update_i += 1
else: else:
i.update = None i.update = None
# swap SharedVariable # swap SharedVariable
if swap != None: if swap is not None:
self.__swapSV(swap, ins, fg_cpy) self.__swapSV(swap, ins, fg_cpy)
# the name of SV we swapped # the name of SV we swapped
swapped_sv = swap.keys() swapped_sv = swap.keys()
...@@ -627,7 +627,7 @@ class Function(object): ...@@ -627,7 +627,7 @@ class Function(object):
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
input_storage = [ i.value for i in ins ] input_storage = [i.value for i in ins]
# reinitialize new maker and create new function # reinitialize new maker and create new function
f_cpy = maker.__class__(inputs=ins, outputs=outs, f_cpy = maker.__class__(inputs=ins, outputs=outs,
fgraph=fg_cpy, fgraph=fg_cpy,
...@@ -643,7 +643,7 @@ class Function(object): ...@@ -643,7 +643,7 @@ class Function(object):
f_cpy.input_storage): f_cpy.input_storage):
is_const = isinstance(in_ori.variable, theano.tensor.Constant) is_const = isinstance(in_ori.variable, theano.tensor.Constant)
# In instances' name default to vairables' name # In instances' name default to vairables' name
swapped = swap != None and in_ori.name in swapped_sv swapped = swap is not None and in_ori.name in swapped_sv
if (is_const or not in_ori.mutable) and not swapped: if (is_const or not in_ori.mutable) and not swapped:
cpy.data = ori.data cpy.data = ori.data
...@@ -652,7 +652,7 @@ class Function(object): ...@@ -652,7 +652,7 @@ class Function(object):
# Reconstruct Function.finder. # Reconstruct Function.finder.
# Function.value and Function.data work # Function.value and Function.data work
for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs): for ori, cpy in zip(maker.inputs, f_cpy.maker.inputs):
swapped = swap != None and ori.name in swapped_sv swapped = swap is not None and ori.name in swapped_sv
if not swapped: if not swapped:
f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable) f_cpy.finder[ori.variable] = f_cpy.finder.pop(cpy.variable)
else: else:
...@@ -673,7 +673,7 @@ class Function(object): ...@@ -673,7 +673,7 @@ class Function(object):
Returns: Returns:
None None
""" """
def checkSV( sv_ori, sv_rpl ): def checkSV(sv_ori, sv_rpl):
""" """
Assert two SharedVariable follow some restirctions: Assert two SharedVariable follow some restirctions:
1. same type 1. same type
...@@ -682,7 +682,7 @@ class Function(object): ...@@ -682,7 +682,7 @@ class Function(object):
assert sv_ori.type == sv_rpl.type, ( assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with origianl one", "Type of given SharedVariable conflicts with origianl one",
"Type of given SharedVariable:", sv_rpl.type, "Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type ) "Type of original SharedVariable:", sv_ori.type)
exist_names = [i.variable.name for i in ins] exist_names = [i.variable.name for i in ins]
swap_names = swap.keys() swap_names = swap.keys()
...@@ -690,7 +690,7 @@ class Function(object): ...@@ -690,7 +690,7 @@ class Function(object):
# Check if given names exist # Check if given names exist
for name in swap_names: for name in swap_names:
if name not in exist_names: if name not in exist_names:
warnings.warn( "Given name: %s wasn't found" % (name) ) warnings.warn("Given name: %s wasn't found" % (name))
# Swap SharedVairable in fgraph and ins # Swap SharedVairable in fgraph and ins
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)): for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
...@@ -699,11 +699,11 @@ class Function(object): ...@@ -699,11 +699,11 @@ class Function(object):
if in_v.name in swap_names: if in_v.name in swap_names:
# In the fgraph we use the cloned SharedVariable # In the fgraph we use the cloned SharedVariable
swap_sv = swap[in_v.name].clone() swap_sv = swap[in_v.name].clone()
checkSV( in_v, swap_sv) checkSV(in_v, swap_sv)
# Swap SharedVariable in fgraph # Swap SharedVariable in fgraph
fg_cpy.inputs[index] = swap_sv fg_cpy.inputs[index] = swap_sv
fg_cpy.replace( in_v, swap_sv, reason="Swap SV") fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# swap variable and value of In instances # swap variable and value of In instances
i.variable = swap_sv i.variable = swap_sv
......
...@@ -230,7 +230,7 @@ if run_memory_usage_tests: ...@@ -230,7 +230,7 @@ if run_memory_usage_tests:
a = cuda.CudaNdarray(n) a = cuda.CudaNdarray(n)
a.sum() a.sum()
assert c == sys.getrefcount(n) assert c == sys.getrefcount(n)
del a
if not i % 1000: if not i % 1000:
print('.', end=' ') print('.', end=' ')
print(gc.collect(), end=' ') print(gc.collect(), end=' ')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论