提交 d24c60f4 authored 作者: Roy Xue's avatar Roy Xue

update

上级 6ce06844
...@@ -729,6 +729,9 @@ class ProfileStats(object): ...@@ -729,6 +729,9 @@ class ProfileStats(object):
return False return False
maybe_executed = set() maybe_executed = set()
for var in fgraph.inputs:
for c, _ in var.clients:
maybe_executed.add(c)
def min_memory_generator(node_list): def min_memory_generator(node_list):
''' '''
...@@ -742,20 +745,18 @@ class ProfileStats(object): ...@@ -742,20 +745,18 @@ class ProfileStats(object):
''' '''
global maybe_executed global maybe_executed
for i in range(len(node_list)): for i in range(len(node_list)):
v = node_list[i:i+1] v = node_list[i:i+1]
if len(node_list) == check_len or v[0] in maybe_executed: if v[0] in maybe_executed:
if check_node_state(v[0]): if check_node_state(v[0]):
for node in v[0].outputs: for node in v[0].outputs:
compute_map[node][0] = 1 compute_map[node][0] = 1
for c, _ in node.clients: for c, _ in node.clients:
if c == "output": if c != "output":
pass maybe_executed.add(c)
else:
maybe_executed.add(node)
if len(node_list) == 1: if len(node_list) == 1:
yield v yield v
maybe_executed = set()
else: else:
rest = node_list[ :i] + node_list[i+1: ] rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest): for p in min_memory_generator(rest):
...@@ -765,6 +766,8 @@ class ProfileStats(object): ...@@ -765,6 +766,8 @@ class ProfileStats(object):
min_order = [] min_order = []
print node_list
for order in min_memory_generator(node_list): for order in min_memory_generator(node_list):
post_thunk_old_storage = [] post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order) computed, last_user = theano.gof.link.gc_helper(order)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论