提交 3a0e1c42 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the print_extra print to the right place.

上级 27cb3e25
...@@ -1325,7 +1325,7 @@ class ProfileStats(object): ...@@ -1325,7 +1325,7 @@ class ProfileStats(object):
print("-----------------", file=file) print("-----------------", file=file)
self.optimizer_profile[0].print_profile(file, self.optimizer_profile[0].print_profile(file,
self.optimizer_profile[1]) self.optimizer_profile[1])
self.print_extra() self.print_extra(file)
self.print_tips(file) self.print_tips(file)
def print_tips(self, file): def print_tips(self, file):
...@@ -1472,11 +1472,11 @@ class ProfileStats(object): ...@@ -1472,11 +1472,11 @@ class ProfileStats(object):
if not printed_tip: if not printed_tip:
print(" Sorry, no tip for today.", file=file) print(" Sorry, no tip for today.", file=file)
def print_extra(self): def print_extra(self, file):
params = [self.message, self.compile_time, self.fct_call_time, params = [self.message, self.compile_time, self.fct_call_time,
self.apply_time, self.apply_cimpl, self.output_size] self.apply_time, self.apply_cimpl, self.output_size]
for f in _profiler_printers: for f in _profiler_printers:
f(*params) f(*params, file=file)
class ScanProfileStats(ProfileStats): class ScanProfileStats(ProfileStats):
......
...@@ -1542,12 +1542,12 @@ class GpuSplit(HideC, Split): ...@@ -1542,12 +1542,12 @@ class GpuSplit(HideC, Split):
@theano.compile.profiling.register_profiler_printer @theano.compile.profiling.register_profiler_printer
def profile_printer(message, compile_time, fct_call_time, def profile_printer(message, compile_time, fct_call_time,
apply_time, apply_cimpl, outputs_size): apply_time, apply_cimpl, outputs_size, file):
if any([x.op.__class__.__name__.lower().startswith("gpu") if any([x.op.__class__.__name__.lower().startswith("gpu")
for x in apply_time.keys()]): for x in apply_time.keys()]):
local_time = sum(apply_time.values()) local_time = sum(apply_time.values())
print() print('', file=file)
print('Some info useful for gpu:') print('Some info useful for gpu:', file=file)
fgraphs = set() fgraphs = set()
for node in apply_time.keys(): for node in apply_time.keys():
...@@ -1563,23 +1563,23 @@ def profile_printer(message, compile_time, fct_call_time, ...@@ -1563,23 +1563,23 @@ def profile_printer(message, compile_time, fct_call_time,
gpu += t gpu += t
else: else:
cpu += t cpu += t
print() print('', file=file)
print(" Spent %.3fs(%.2f%%) in cpu Op, %.3fs(%.2f%%) in gpu Op and %.3fs(%.2f%%) transfert Op" % ( print(" Spent %.3fs(%.2f%%) in cpu Op, %.3fs(%.2f%%) in gpu Op and %.3fs(%.2f%%) transfert Op" % (
cpu, cpu / local_time * 100, gpu, gpu / local_time * 100, cpu, cpu / local_time * 100, gpu, gpu / local_time * 100,
trans, trans / local_time * 100)) trans, trans / local_time * 100), file=file)
print() print('', file=file)
print(" Theano function input that are float64") print(" Theano function input that are float64", file=file)
print(" <fct name> <input name> <input type> <str input>") print(" <fct name> <input name> <input type> <str input>", file=file)
for fg in fgraphs: for fg in fgraphs:
for i in fg.inputs: for i in fg.inputs:
if hasattr(i.type, 'dtype') and i.type.dtype == 'float64': if hasattr(i.type, 'dtype') and i.type.dtype == 'float64':
print(' ', fg.name, i.name, i.type, i) print(' ', fg.name, i.name, i.type, i, file=file)
print() print('', file=file)
print(" List of apply that don't have float64 as input but have float64 in outputs") print(" List of apply that don't have float64 as input but have float64 in outputs", file=file)
print(" (Useful to know if we forgot some cast when using floatX=float32 or gpu code)") print(" (Useful to know if we forgot some cast when using floatX=float32 or gpu code)", file=file)
print(' <Apply> <Apply position> <fct name> <inputs type> <outputs type>') print(' <Apply> <Apply position> <fct name> <inputs type> <outputs type>', file=file)
for fg in fgraphs: for fg in fgraphs:
for idx, node in enumerate(fg.toposort()): for idx, node in enumerate(fg.toposort()):
if (any(hasattr(i, 'dtype') and i.dtype == 'float64' if (any(hasattr(i, 'dtype') and i.dtype == 'float64'
...@@ -1587,11 +1587,13 @@ def profile_printer(message, compile_time, fct_call_time, ...@@ -1587,11 +1587,13 @@ def profile_printer(message, compile_time, fct_call_time,
not any(hasattr(i, 'dtype') and i.dtype == 'float64' not any(hasattr(i, 'dtype') and i.dtype == 'float64'
for i in node.inputs)): for i in node.inputs)):
print(' ', str(node), idx, fg.name, end=' ') print(' ', str(node), idx, fg.name, end=' ',
file=file)
print(str([getattr(i, 'dtype', None) print(str([getattr(i, 'dtype', None)
for i in node.inputs]), end=' ') for i in node.inputs]), end=' ', file=file)
print(str([getattr(i, 'dtype', None) print(str([getattr(i, 'dtype', None)
for i in node.outputs])) for i in node.outputs]), file=file)
print('', file=file)
class GpuEye(GpuKernelBase, Op): class GpuEye(GpuKernelBase, Op):
......
...@@ -2869,15 +2869,15 @@ gof.ops_with_inner_function[Scan] = 'fn' ...@@ -2869,15 +2869,15 @@ gof.ops_with_inner_function[Scan] = 'fn'
@theano.compile.profiling.register_profiler_printer @theano.compile.profiling.register_profiler_printer
def profile_printer(message, compile_time, fct_call_time, def profile_printer(message, compile_time, fct_call_time,
apply_time, apply_cimpl, outputs_size): apply_time, apply_cimpl, outputs_size, file):
# Scan overhead profile # Scan overhead profile
if any([isinstance(node.op, Scan) and v > 0 for node, v in if any([isinstance(node.op, Scan) and v > 0 for node, v in
apply_time.items()]): apply_time.items()]):
print() print('', file=file)
print('Scan overhead:') print('Scan overhead:', file=file)
print('<Scan op time(s)> <sub scan fct time(s)> <sub scan op ' print('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan ' 'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>') 'op time(% scan op time)> <node>', file=file)
fct_call = set() fct_call = set()
for node in apply_time.keys(): for node in apply_time.keys():
...@@ -2899,13 +2899,13 @@ def profile_printer(message, compile_time, fct_call_time, ...@@ -2899,13 +2899,13 @@ def profile_printer(message, compile_time, fct_call_time,
scan_fct_time, scan_fct_time,
scan_op_time, scan_op_time,
scan_fct_time / v * 100, scan_fct_time / v * 100,
scan_op_time / v * 100), node) scan_op_time / v * 100), node, file=file)
else: else:
print((' The node took 0s, so we can not ' print((' The node took 0s, so we can not '
'compute the overhead'), node) 'compute the overhead'), node, file=file)
print(' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%' % ( print(' total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%' % (
total_super_scan_time, total_super_scan_time,
total_scan_fct_time, total_scan_fct_time,
total_scan_op_time, total_scan_op_time,
total_scan_fct_time / total_super_scan_time * 100, total_scan_fct_time / total_super_scan_time * 100,
total_scan_op_time / total_super_scan_time * 100)) total_scan_op_time / total_super_scan_time * 100), file=file)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论