提交 12b11541 authored 作者: Roy Xue's avatar Roy Xue

updates:

1. remove compute_map.copy from vm.py 2. add new compute_map in compile.py 3. make code more simple
上级 2aca2e14
...@@ -19,6 +19,7 @@ import copy ...@@ -19,6 +19,7 @@ import copy
import os import os
import sys import sys
import time import time
from collections import defaultdict
import numpy import numpy
...@@ -699,12 +700,11 @@ class ProfileStats(object): ...@@ -699,12 +700,11 @@ class ProfileStats(object):
min_mem = 0 min_mem = 0
current_mem = 0 current_mem = 0
compute_map = fgraph.profile.compute_map compute_map = defaultdict(lambda: [0])
# compute_map use to check if a node is valid # compute_map use to check if a node is valid
for node in node_list: for node in fgraph.inputs:
for v in node.outputs: compute_map[node][0] = 1
compute_map[v][0] = 0
def check_node_state(node): def check_node_state(node):
""" """
...@@ -748,14 +748,9 @@ class ProfileStats(object): ...@@ -748,14 +748,9 @@ class ProfileStats(object):
for i in v[0].outputs: for i in v[0].outputs:
compute_map[i][0] = 0 compute_map[i][0] = 0
temp = []
min_order = [] min_order = []
for order in min_memory_generator(node_list, compute_map): for order in min_memory_generator(node_list, compute_map):
temp.append(order)
for order in temp:
post_thunk_old_storage = [] post_thunk_old_storage = []
computed, last_user = theano.gof.link.gc_helper(order) computed, last_user = theano.gof.link.gc_helper(order)
for node in order: for node in order:
...@@ -766,11 +761,10 @@ class ProfileStats(object): ...@@ -766,11 +761,10 @@ class ProfileStats(object):
(input not in fgraph.outputs) and (input not in fgraph.outputs) and
node == last_user[input]]) node == last_user[input]])
current_mem = count_running_memory(order, post_thunk_old_storage, nodes_mem)[2] current_mem = count_running_memory(order, post_thunk_old_storage, nodes_mem)[2]
current_order = order
if current_mem < min_mem: if current_mem < min_mem:
min_mem = current_mem min_mem = current_mem
min_order = current_order min_order = order
return min_order, min_mem return min_order, min_mem
......
...@@ -147,9 +147,6 @@ class VM(object): ...@@ -147,9 +147,6 @@ class VM(object):
if hasattr(self, 'node_cleared_order'): if hasattr(self, 'node_cleared_order'):
profile.node_cleared_order = self.node_cleared_order[:] profile.node_cleared_order = self.node_cleared_order[:]
if hasattr(self, 'compute_map'):
profile.compute_map = self.compute_map.copy()
# clear the timer info out of the buffers # clear the timer info out of the buffers
for i in xrange(len(self.call_times)): for i in xrange(len(self.call_times)):
self.call_times[i] = 0.0 self.call_times[i] = 0.0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论