提交 d86ea664 authored 作者: Frederic's avatar Frederic

theano-cache list now print more stuff

- print the number of module with X number of different key. - print the module whose key.pkl file take more then 1M.
上级 530bad75
...@@ -107,6 +107,7 @@ AddConfigVar('compiledir', ...@@ -107,6 +107,7 @@ AddConfigVar('compiledir',
def print_compiledir_content(): def print_compiledir_content():
max_key_file_size = 1 * 1024 * 1024
def flatten(a): def flatten(a):
if isinstance(a, (tuple, list, set)): if isinstance(a, (tuple, list, set)):
...@@ -121,11 +122,15 @@ def print_compiledir_content(): ...@@ -121,11 +122,15 @@ def print_compiledir_content():
table = [] table = []
more_than_one_ops = 0 more_than_one_ops = 0
zeros_op = 0 zeros_op = 0
big_key_files = []
total_key_sizes = 0
nb_keys = {}
for dir in os.listdir(compiledir): for dir in os.listdir(compiledir):
file = None file = None
try: try:
try: try:
file = open(os.path.join(compiledir, dir, "key.pkl"), 'rb') filename = os.path.join(compiledir, dir, "key.pkl")
file = open(filename, 'rb')
keydata = cPickle.load(file) keydata = cPickle.load(file)
ops = list(set([x for x in flatten(keydata.keys) ops = list(set([x for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Op)])) if isinstance(x, theano.gof.Op)]))
...@@ -137,6 +142,14 @@ def print_compiledir_content(): ...@@ -137,6 +142,14 @@ def print_compiledir_content():
types = list(set([x for x in flatten(keydata.keys) types = list(set([x for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type)])) if isinstance(x, theano.gof.Type)]))
table.append((dir, ops[0], types)) table.append((dir, ops[0], types))
size = os.path.getsize(filename)
total_key_sizes += size
if size > max_key_file_size:
big_key_files.append((dir, size, ops))
nb_keys.setdefault(len(keydata.keys), 0)
nb_keys[len(keydata.keys)] += 1
except IOError: except IOError:
pass pass
finally: finally:
...@@ -159,6 +172,24 @@ def print_compiledir_content(): ...@@ -159,6 +172,24 @@ def print_compiledir_content():
table_op_class = sorted(table_op_class.iteritems(), key=lambda t: t[1]) table_op_class = sorted(table_op_class.iteritems(), key=lambda t: t[1])
for op_class, nb in table_op_class: for op_class, nb in table_op_class:
print op_class, nb print op_class, nb
big_key_files = sorted(big_key_files, key=lambda t: str(t[1]))
big_total_size = sum([size for dir, size, ops in big_key_files])
print
print "Directory with a key file bigger then %d bytes" % max_key_file_size,
print "(probably they there is a big constant inside)"
print "There total are %d bytes on a total size of %d for key files" % (
big_total_size, total_key_sizes)
for dir, size, ops in big_key_files:
print dir, size, ops
nb_keys = sorted(nb_keys.iteritems(), key=lambda t: t[0])
print
print "Number of key for a compiled module"
print "nb key/nb module with that number of key"
for n_k, n_m in nb_keys:
print n_k, n_m
print ("Skipped %d files that contained more than" print ("Skipped %d files that contained more than"
" 1 op (was compiled with the C linker)" % more_than_one_ops) " 1 op (was compiled with the C linker)" % more_than_one_ops)
print ("Skipped %d files that contained 0 op " print ("Skipped %d files that contained 0 op "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论