提交 d173f8cb authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2106 from RoyXue/GSoC2014_part2

Algorithm Speed Up
...@@ -785,28 +785,18 @@ class ProfileStats(object): ...@@ -785,28 +785,18 @@ class ProfileStats(object):
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs: for var in fgraph.inputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
for var in node_list:
def check_node_state(node): for val in var.inputs:
""" if isinstance(val, graph.Constant):
Check if an Apply node is valid(has inputs). compute_map[val][0] = 1
:param node: Apply Node
"""
inputs = node.inputs
outputs = node.outputs
deps = inputs + node.destroy_dependencies
# TODO: Move at compute_map creation to speed things up.
for node in inputs:
if isinstance(node, graph.Constant):
compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in deps)
return computed_ins
# Initial executable_nodes # Initial executable_nodes
executable_nodes = set() executable_nodes = set()
for var in fgraph.inputs: for var in fgraph.inputs:
for c, _ in var.clients: for c, _ in var.clients:
if c != "output" and check_node_state(c): if c != "output":
deps = c.inputs + c.destroy_dependencies
if all(compute_map[v][0] for v in deps):
executable_nodes.add(c) executable_nodes.add(c)
def min_memory_generator(executable_nodes, viewed_by, view_of): def min_memory_generator(executable_nodes, viewed_by, view_of):
...@@ -826,13 +816,12 @@ class ProfileStats(object): ...@@ -826,13 +816,12 @@ class ProfileStats(object):
if max_mem_count > mem_bound: if max_mem_count > mem_bound:
continue continue
view_of_temp = view_of.copy() viewof_change = []
# We don't want a shallow copy, but we don't want # Use to track view_of changes
# a deep copy. So this do a "middle" copy, where
# we copy the dict and the list, but not the var viewedby_add = defaultdict(lambda: [])
viewed_by_temp = {} viewedby_remove = defaultdict(lambda: [])
for k, v in viewed_by.iteritems(): # Use to track viewed_by changes
viewed_by_temp[k] = list(v)
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
...@@ -865,9 +854,11 @@ class ProfileStats(object): ...@@ -865,9 +854,11 @@ class ProfileStats(object):
# input. # input.
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
# We keep trac of view only again the original # We keep trac of view only again the original
origin = view_of_temp.get(ins, ins) origin = view_of.get(ins, ins)
view_of_temp[out] = origin view_of[out] = origin
viewed_by_temp[origin].append(out) viewof_change.append(out)
viewed_by[origin].append(out)
viewedby_add[origin].append(out)
else: else:
mem_created += var_mem[out] mem_created += var_mem[out]
idx += 1 idx += 1
...@@ -877,19 +868,20 @@ class ProfileStats(object): ...@@ -877,19 +868,20 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc. # Mimic the combination of Theano and Python gc.
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of_temp and assert not (ins in view_of and
viewed_by_temp[ins]) viewed_by[ins])
# We track of the original var, so this shouldn't happen # We track of the original var, so this shouldn't happen
if (dependencies[ins] and if (dependencies[ins] and
ins not in fgraph.outputs and ins not in fgraph.outputs and
ins.owner and ins.owner and
all([compute_map[v][0] for v in dependencies[ins]])): all([compute_map[v][0] for v in dependencies[ins]])):
if ins not in view_of_temp and not viewed_by_temp.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
mem_freed += var_mem[ins] mem_freed += var_mem[ins]
elif ins in view_of_temp: elif ins in view_of:
origin = view_of_temp[ins] origin = view_of[ins]
viewed_by_temp[origin].remove(ins) viewed_by[origin].remove(ins)
if (not viewed_by_temp[origin] and viewedby_remove[origin].append(ins)
if (not viewed_by[origin] and
origin not in fgraph.inputs and origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)): not isinstance(origin, theano.Constant)):
mem_freed += var_mem[origin] mem_freed += var_mem[origin]
...@@ -902,19 +894,17 @@ class ProfileStats(object): ...@@ -902,19 +894,17 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
for c, _ in var.clients: for c, _ in var.clients:
if c != "output" and check_node_state(c): if c != "output":
deps = c.inputs + c.destroy_dependencies
if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c) new_exec_nodes.add(c)
if not new_exec_nodes: if not new_exec_nodes:
yield [node]
# Check and Update mem_bound # Check and Update mem_bound
if max_mem_count < mem_bound: if max_mem_count < mem_bound:
mem_bound = max_mem_count mem_bound = max_mem_count
else: else:
for p in min_memory_generator(new_exec_nodes, min_memory_generator(new_exec_nodes, viewed_by, view_of)
viewed_by_temp,
view_of_temp):
yield [node]+p
# Reset track variables # Reset track variables
mem_count -= mem_created mem_count -= mem_created
...@@ -923,6 +913,18 @@ class ProfileStats(object): ...@@ -923,6 +913,18 @@ class ProfileStats(object):
for var in node.outputs: for var in node.outputs:
compute_map[var][0] = 0 compute_map[var][0] = 0
for k_remove, v_remove in viewedby_remove.iteritems():
for i in v_remove:
viewed_by[k_remove].append(i)
for k_add, v_add in viewedby_add.iteritems():
for i in v_add:
viewed_by[k_add].remove(i)
for k in viewof_change:
del view_of[k]
# two data structure used to mimic Python gc # two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]} viewed_by = {} # {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref count. But we use a list, not just the ref count value.
...@@ -932,11 +934,7 @@ class ProfileStats(object): ...@@ -932,11 +934,7 @@ class ProfileStats(object):
view_of = {} # {var1: original var viewed by var1} view_of = {} # {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view. # The orignal mean that we don't keep trac of all the intermediate relationship in the view.
# Loop all valid orders and find min peak(store in mem_bound) min_memory_generator(executable_nodes, viewed_by, view_of)
for order in min_memory_generator(executable_nodes,
viewed_by,
view_of):
continue
return mem_bound return mem_bound
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论