提交 2601a061 作者: Roy Xue

Merge pull request #4 from nouiz/GSoC2014_part2

a few fixes
...@@ -722,7 +722,7 @@ class ProfileStats(object): ...@@ -722,7 +722,7 @@ class ProfileStats(object):
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
viewed_by[origin].remove(ins) viewed_by[origin].remove(ins)
if not viewed_by[origin]: if not viewed_by[origin] and origin not in fgraph.inputs:
running_memory_size -= var_mem[origin] running_memory_size -= var_mem[origin]
else: else:
# ins is viewed_by something else, so its memory isn't freed # ins is viewed_by something else, so its memory isn't freed
...@@ -736,6 +736,7 @@ class ProfileStats(object): ...@@ -736,6 +736,7 @@ class ProfileStats(object):
mem_count = 0 mem_count = 0
max_mem_count = 0 max_mem_count = 0
mem_bound = numpy.inf mem_bound = numpy.inf
# This takes only the inputs/outputs dependencies into account.
dependencies = fgraph.profile.dependencies dependencies = fgraph.profile.dependencies
# Initial compute_map which is used to check if a node is valid # Initial compute_map which is used to check if a node is valid
...@@ -773,7 +774,7 @@ class ProfileStats(object): ...@@ -773,7 +774,7 @@ class ProfileStats(object):
viewed_by = {}# {var1: [vars that view var1]} viewed_by = {}# {var1: [vars that view var1]}
# The len of the list is the value of the Python ref count. But we use a list, not just the ref count value. # The len of the list is the value of the Python ref count. But we use a list, not just the ref count value.
# This is safer and helps detect potential bugs in the algo # This is safer and helps detect potential bugs in the algo
for var in fgraph.apply_nodes: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
view_of = {}# {var1: original var viewed by var1} view_of = {}# {var1: original var viewed by var1}
# The original means that we don't keep track of all the intermediate relationships in the view. # The original means that we don't keep track of all the intermediate relationships in the view.
...@@ -833,7 +834,7 @@ class ProfileStats(object): ...@@ -833,7 +834,7 @@ class ProfileStats(object):
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
viewed_by[origin].remove(ins) viewed_by[origin].remove(ins)
if not viewed_by[origin]: if not viewed_by[origin] and origin not in fgraph.inputs:
mem_freed += var_mem[origin] mem_freed += var_mem[origin]
else: else:
# ins is viewed_by something else, so its memory isn't freed # ins is viewed_by something else, so its memory isn't freed
......
...@@ -44,9 +44,10 @@ def test_profiling(): ...@@ -44,9 +44,10 @@ def test_profiling():
# regression testing for future algo speed up # regression testing for future algo speed up
the_string = buf.getvalue() the_string = buf.getvalue()
# assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string lines1 = [l for l in the_string.split("\n") if "Max if linker" in l]
# assert "Minimum peak from all valid apply node order is 8192KB" in the_string lines2 = [l for l in the_string.split("\n") if "Minimum peak" in l]
print the_string assert "Max if linker=cvm(default): 8208KB (16400KB)" in the_string, (lines1, lines2)
assert "Minimum peak from all valid apply node order is 8192KB" in the_string, (lines1, lines2)
finally: finally:
theano.config.profile = config1 theano.config.profile = config1
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论