提交 31fdf12a authored 作者: Frederic Bastien's avatar Frederic Bastien

fix profile printing of scan when it is not executed.

上级 256b7f59
...@@ -1974,7 +1974,7 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call, ...@@ -1974,7 +1974,7 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, op_cimpl, message, outputs_size, apply_time, op_cimpl, message, outputs_size,
other_time): other_time):
# Scan overhead profile # Scan overhead profile
if any([isinstance(node.op, (Scan, ScanGrad)) for (_,node) in apply_time.keys()]): if any([isinstance(node.op, (Scan, ScanGrad)) and v>0 for (_,node),v in apply_time.items()]):
print print
print 'Scan overhead:' print 'Scan overhead:'
print '<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time/scan op time(%)> <sub scan op time/scan op time(%)> <node>' print '<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time/scan op time(%)> <sub scan op time/scan op time(%)> <node>'
...@@ -1982,7 +1982,7 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call, ...@@ -1982,7 +1982,7 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
total_scan_fct_time = 0 total_scan_fct_time = 0
total_scan_op_time = 0 total_scan_op_time = 0
for (_,node),v in apply_time.items(): for (_,node),v in apply_time.items():
if isinstance(node.op, (Scan, ScanGrad)): if isinstance(node.op, (Scan, ScanGrad)) and v > 0:
scan_fct_time = sum(node.op.mode_instance.fct_call_time.values()) scan_fct_time = sum(node.op.mode_instance.fct_call_time.values())
scan_op_time = sum(node.op.mode_instance.local_time) scan_op_time = sum(node.op.mode_instance.local_time)
total_super_scan_time += v total_super_scan_time += v
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论