提交 97acb91b authored 作者: Roy Xue's avatar Roy Xue

new test_profiling.py

increase the complexity of test.
上级 12b11541
...@@ -737,16 +737,16 @@ class ProfileStats(object): ...@@ -737,16 +737,16 @@ class ProfileStats(object):
for i in range(len(node_list)): for i in range(len(node_list)):
v = node_list[i:i+1] v = node_list[i:i+1]
if check_node_state(v[0]): if check_node_state(v[0]):
for i in v[0].outputs: for node in v[0].outputs:
compute_map[i][0] = 1 compute_map[node][0] = 1
if len(node_list) == 1: if len(node_list) == 1:
yield v yield v
else: else:
rest = node_list[ :i] + node_list[i+1: ] rest = node_list[ :i] + node_list[i+1: ]
for p in min_memory_generator(rest, compute_map): for p in min_memory_generator(rest, compute_map):
yield v+p yield v+p
for i in v[0].outputs: for node in v[0].outputs:
compute_map[i][0] = 0 compute_map[node][0] = 0
min_order = [] min_order = []
......
...@@ -15,20 +15,34 @@ def test_profiling(): ...@@ -15,20 +15,34 @@ def test_profiling():
theano.config.profile = True theano.config.profile = True
theano.config.profile_memory = True theano.config.profile_memory = True
x = T.dvector("x") val1 = T.dvector("val1")
y = T.dvector("y") val2 = T.dvector("val2")
z = x + y val3 = T.dvector("val3")
val4 = T.dvector("val4")
val5 = T.dvector("val5")
val6 = T.dvector("val6")
val7 = T.dvector("val7")
val8 = T.dvector("val8")
val9 = T.dvector("val9")
x = [val1, val2, val3, val4, val5, val6, val7, val8, val9]
z = [x[i] + x[i+1] for i in range(len(x)-1)] + [T.outer(x[i], x[i+1]).sum() for i in range(len(x)-1)]
p = theano.ProfileStats(False) p = theano.ProfileStats(False)
if theano.config.mode in ["DebugMode", "DEBUG_MODE"]: if theano.config.mode in ["DebugMode", "DEBUG_MODE"]:
m = "FAST_RUN" m = "FAST_RUN"
else: else:
m = None m = None
f = theano.function([x, y], z, profile=p, name="test_profiling",
f = theano.function([val1, val2, val3, val4, val5, val6, val7, val8, val9], z, profile=p, name="test_profiling",
mode=m) mode=m)
output = f([1, 2, 3, 4], [1, 1, 1, 1])
output = f([0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8], [5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12])
buf = StringIO.StringIO() buf = StringIO.StringIO()
f.profile.summary(buf) f.profile.summary(buf)
finally: finally:
theano.config.profile = old1 theano.config.profile = old1
theano.config.profile_memory = old2 theano.config.profile_memory = old2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论