提交 6ce06844 authored 作者: Roy Xue's avatar Roy Xue

Unfinished speed-up part

上级 0df71f39
...@@ -24,7 +24,7 @@ from collections import defaultdict ...@@ -24,7 +24,7 @@ from collections import defaultdict
import numpy import numpy
import theano import theano
from theano.gof import Constant from theano.gof import graph
from theano.configparser import AddConfigVar, BoolParam, IntParam from theano.configparser import AddConfigVar, BoolParam, IntParam
...@@ -694,12 +694,14 @@ class ProfileStats(object): ...@@ -694,12 +694,14 @@ class ProfileStats(object):
max_minimum_peak = 0 max_minimum_peak = 0
def count_minimum_peak(node_list, fgraph, nodes_mem): def count_minimum_peak(node_list, fgraph, nodes_mem):
global maybe_executed
mem_list = [] mem_list = []
order_index = 0 order_index = 0
order = [] order = []
node_list = list(node_list) node_list = list(node_list)
min_mem = sys.maxint min_mem = sys.maxint
current_mem = 0 current_mem = 0
check_len = len(node_list)
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
# compute_map is used to check if a node is valid # compute_map is used to check if a node is valid
...@@ -716,7 +718,7 @@ class ProfileStats(object): ...@@ -716,7 +718,7 @@ class ProfileStats(object):
outputs = node.outputs outputs = node.outputs
deps = inputs + node.destroy_dependencies deps = inputs + node.destroy_dependencies
for node in deps: for node in deps:
if isinstance(node, Constant): if isinstance(node, graph.Constant):
compute_map[node][0] = 1 compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in inputs) computed_ins = all(compute_map[v][0] for v in inputs)
computed_outs = all(compute_map[v][0] for v in outputs) computed_outs = all(compute_map[v][0] for v in outputs)
...@@ -726,6 +728,8 @@ class ProfileStats(object): ...@@ -726,6 +728,8 @@ class ProfileStats(object):
else: else:
return False return False
maybe_executed = set()
def min_memory_generator(node_list): def min_memory_generator(node_list):
''' '''
enumerate all valid orders for the list of nodes in node_list enumerate all valid orders for the list of nodes in node_list
...@@ -736,19 +740,28 @@ class ProfileStats(object): ...@@ -736,19 +740,28 @@ class ProfileStats(object):
:param compute_map: simulate the node execution steps to update compute_map :param compute_map: simulate the node execution steps to update compute_map
''' '''
global maybe_executed
for i in range(len(node_list)): for i in range(len(node_list)):
v = node_list[i:i+1] v = node_list[i:i+1]
if check_node_state(v[0]): if len(node_list) == check_len or v[0] in maybe_executed:
for node in v[0].outputs: if check_node_state(v[0]):
compute_map[node][0] = 1 for node in v[0].outputs:
if len(node_list) == 1: compute_map[node][0] = 1
yield v for c, _ in node.clients:
else: if c == "output":
rest = node_list[ :i] + node_list[i+1: ] pass
for p in min_memory_generator(rest): else:
yield v+p maybe_executed.add(node)
for node in v[0].outputs: if len(node_list) == 1:
compute_map[node][0] = 0 yield v
maybe_executed = set()
else:
rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest):
yield v+p
for node in v[0].outputs:
compute_map[node][0] = 0
min_order = [] min_order = []
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论