提交 6ce06844 authored 作者: Roy Xue's avatar Roy Xue

Undone speed up part

上级 0df71f39
......@@ -24,7 +24,7 @@ from collections import defaultdict
import numpy
import theano
from theano.gof import Constant
from theano.gof import graph
from theano.configparser import AddConfigVar, BoolParam, IntParam
......@@ -694,12 +694,14 @@ class ProfileStats(object):
max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed
mem_list = []
order_index = 0
order = []
node_list = list(node_list)
min_mem = sys.maxint
current_mem = 0
check_len = len(node_list)
compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid
......@@ -716,7 +718,7 @@ class ProfileStats(object):
outputs = node.outputs
deps = inputs + node.destroy_dependencies
for node in deps:
if isinstance(node, Constant):
if isinstance(node, graph.Constant):
compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in inputs)
computed_outs = all(compute_map[v][0] for v in outputs)
......@@ -726,6 +728,8 @@ class ProfileStats(object):
else:
return False
maybe_executed = set()
def min_memory_generator(node_list):
'''
enumerate all valid orders for the list of nodes in node_list
......@@ -736,19 +740,28 @@ class ProfileStats(object):
:param compute_map: simulate the node execution steps to update compute_map
'''
global maybe_executed
for i in range(len(node_list)):
v = node_list[i:i+1]
if check_node_state(v[0]):
for node in v[0].outputs:
compute_map[node][0] = 1
if len(node_list) == 1:
yield v
else:
rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest):
yield v+p
for node in v[0].outputs:
compute_map[node][0] = 0
if len(node_list) == check_len or v[0] in maybe_executed:
if check_node_state(v[0]):
for node in v[0].outputs:
compute_map[node][0] = 1
for c, _ in node.clients:
if c == "output":
pass
else:
maybe_executed.add(node)
if len(node_list) == 1:
yield v
maybe_executed = set()
else:
rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest):
yield v+p
for node in v[0].outputs:
compute_map[node][0] = 0
min_order = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论