added Profiler class, profiler option to PerformLinker and OpWiseCLinker

上级 eaf79124
import unittest import unittest
from link import PerformLinker from link import PerformLinker, Profiler
from cc import * from cc import *
from result import ResultBase from result import ResultBase
from op import Op from op import Op
...@@ -188,6 +188,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -188,6 +188,7 @@ class _test_CLinker(unittest.TestCase):
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0) self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
class _test_OpWiseCLinker(unittest.TestCase): class _test_OpWiseCLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
......
from link import Linker from link import Linker, raise_with_op
from copy import copy from copy import copy
from utils import AbstractFunctionError from utils import AbstractFunctionError
import md5 import md5
...@@ -717,10 +717,15 @@ class OpWiseCLinker(Linker): ...@@ -717,10 +717,15 @@ class OpWiseCLinker(Linker):
the whole env, but saves on compilation time because small changes the whole env, but saves on compilation time because small changes
in the computation graph won't necessarily trigger any recompilation, in the computation graph won't necessarily trigger any recompilation,
only local changes in the Results or Ops that are used. only local changes in the Results or Ops that are used.
If fallback_on_perform is True, OpWiseCLinker will use an op's
perform method if no C version can be generated.
""" """
def __init__(self, env): def __init__(self, env, profiler = None, fallback_on_perform = True):
self.env = env self.env = env
self.profiler = profiler
self.fallback_on_perform = fallback_on_perform
def make_thunk(self, inplace = False): def make_thunk(self, inplace = False):
if inplace: if inplace:
...@@ -732,15 +737,33 @@ class OpWiseCLinker(Linker): ...@@ -732,15 +737,33 @@ class OpWiseCLinker(Linker):
env = None env = None
thunks = [] thunks = []
for op in op_order: for op in op_order:
try:
cl = CLinker(op) cl = CLinker(op)
thunk, in_results, out_results = cl.make_thunk(True) thunk, in_results, out_results = cl.make_thunk(True)
thunks.append(thunk) thunks.append(thunk)
except AbstractFunctionError:
if self.fallback_on_perform:
thunks.append(op.perform)
else:
raise
def execute(): if self.profiler is None:
for thunk in thunks: def f():
try:
for thunk, op in zip(thunks, op_order):
thunk() thunk()
except:
raise_with_op(op)
else:
profiler = self.profiler()
def f():
def g():
for thunk, op in zip(thunks, op_order):
profiler.profile_op(thunk, op)
profiler.profile_env(g, env)
f.profiler = profiler
return execute, inputs, outputs return f, inputs, outputs
...@@ -748,7 +771,7 @@ class OpWiseCLinker(Linker): ...@@ -748,7 +771,7 @@ class OpWiseCLinker(Linker):
def _default_checker(x, y): def _default_checker(x, y):
""" """
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
results results contain the same data using ==.
""" """
if x.data != y.data: if x.data != y.data:
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data}) raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data})
...@@ -812,14 +835,15 @@ class DualLinker(Linker): ...@@ -812,14 +835,15 @@ class DualLinker(Linker):
for output1, output2 in zip(op1.outputs, op2.outputs): for output1, output2 in zip(op1.outputs, op2.outputs):
self.checker(output1, output2) self.checker(output1, output2)
except: except:
exc_type, exc_value, exc_trace = sys.exc_info() raise_with_op(op1)
try: # exc_type, exc_value, exc_trace = sys.exc_info()
trace = op1.trace # try:
except AttributeError: # trace = op1.trace
trace = () # except AttributeError:
exc_value.__thunk_trace__ = trace # trace = ()
exc_value.args = exc_value.args + (op1, ) # exc_value.__thunk_trace__ = trace
raise exc_type, exc_value, exc_trace # exc_value.args = exc_value.args + (op1, )
# raise exc_type, exc_value, exc_trace
return f, env1.inputs, env1.outputs return f, env1.inputs, env1.outputs
......
...@@ -31,11 +31,20 @@ def thunk_hook(type, value, trace): ...@@ -31,11 +31,20 @@ def thunk_hook(type, value, trace):
sys.excepthook = thunk_hook sys.excepthook = thunk_hook
def raise_with_op(op, exc_info = None):
if exc_info is None:
exc_info = sys.exc_info()
exc_type, exc_value, exc_trace = exc_info
try:
trace = op.trace
except AttributeError:
trace = ()
exc_value.__thunk_trace__ = trace
exc_value.args = exc_value.args + (op, )
raise exc_type, exc_value, exc_trace
class Linker:
def __init__(self, env): class Linker:
self.env = env
def make_thunk(self, inplace = False): def make_thunk(self, inplace = False):
""" """
...@@ -91,6 +100,11 @@ class Linker: ...@@ -91,6 +100,11 @@ class Linker:
return utils.to_return_values([result.data for result in outputs]) return utils.to_return_values([result.data for result in outputs])
else: else:
return [result.data for result in outputs] return [result.data for result in outputs]
execute.thunk = thunk
try:
execute.profiler = thunk.profiler
except AttributeError:
pass
return execute return execute
...@@ -103,6 +117,10 @@ class PerformLinker(Linker): ...@@ -103,6 +117,10 @@ class PerformLinker(Linker):
the env in the order given by env.toposort. the env in the order given by env.toposort.
""" """
def __init__(self, env, profiler = None):
self.env = env
self.profiler = profiler
def make_thunk(self, inplace = False): def make_thunk(self, inplace = False):
if inplace: if inplace:
env = self.env env = self.env
...@@ -110,81 +128,130 @@ class PerformLinker(Linker): ...@@ -110,81 +128,130 @@ class PerformLinker(Linker):
env = self.env.clone(True) env = self.env.clone(True)
order = env.toposort() order = env.toposort()
thunks = [op.perform for op in order] thunks = [op.perform for op in order]
if self.profiler is None:
def f(): def f():
try: try:
for thunk, op in zip(thunks, order): for thunk, op in zip(thunks, order):
thunk() thunk()
except: except:
exc_type, exc_value, exc_trace = sys.exc_info() raise_with_op(op)
try: else:
trace = op.trace profiler = self.profiler()
except AttributeError: def f():
trace = () def g():
exc_value.__thunk_trace__ = trace for thunk, op in zip(thunks, order):
exc_value.args = exc_value.args + (op, ) profiler.profile_op(thunk, op)
raise exc_type, exc_value, exc_trace profiler.profile_env(g, env)
f.profiler = profiler
return f, env.inputs, env.outputs return f, env.inputs, env.outputs
### PROFILEPERFORMLINKER USES COMPLETELY OUTDATED INTERFACE - FIX ### from collections import defaultdict
import time
# class ProfilePerformLinker(Linker):
class Stats:
# def compile(self): def __init__(self):
# order = self.env.toposort() self.ncalls = 0
# thunks = [op.perform for op in order] self.time = 0
# self.n_calls = 0 self.nfailures = 0
# self.n_thunks = 0 self.time_failures = 0
# self.times = [0.0 for op in self.order] def inc_ncalls(self, v): self.ncalls += v
# def f(): def inc_time(self, v): self.time += v
# for thunk in thunks: def inc_nfailures(self, v): self.nfailures += v
# thunk() def inc_time_failures(self, v): self.time_failures += v
# self.thunk = f
# self.order = order class Profiler:
# self.thunks = thunks """
Collects performance statistics on a function on a per-op
# def slow_call(self): or per-op-class basis.
# """Run the program, timing each thunk.""" """
# for i, thunk in enumerate(self.thunks):
# start_time = time.time() def __init__(self, ignore = [], by_class = True):
# thunk() """
# self.times[i] += time.time() - start_time Creates a Profiler. If by_class is True, stats will
# self.n_thunks += 1 be collected for each Op class, adding the totals for
# self.n_calls += 1 each occurrence of that Op in the computation. If
by_class is False, each node will be timed individually.
# def fast_call(self):
# """Run the program, but only time the entire loop.""" All op classes or ops (depending on the value of by_class)
# start_time = time.time() listed in ignore will not be timed.
# for thunk in self.thunks: """
# thunk() self.ignore = ignore
# self.n_thunks += len(self.thunks) self.stats = defaultdict(Stats)
# self.n_calls += 1 self.started = {}
# self.times[0] += time.time() - start_time self.by_class = by_class
# __call__ = slow_call def profile_env(self, f, env):
stats = self.stats['TOTAL']
# def dump(self, proportion=True): n, t = stats.inc_ncalls, stats.inc_time
# """Print statistics accumulated so far.""" failed = False
# total_time = sum(self.times)
# print self.n_calls, 'calls took', total_time, 'seconds to evaluate', start = time.time()
# print self.n_thunks, 'thunks' try:
f()
# if 0: end = time.time()
# print 'Proportion of CPU per op' except:
# for op, t in zip(self.order, self.times): end = time.time()
# s_op = str(op).split()[0][1:] n, t = stats.inc_nfailures, stats.inc_times_failures
# print " %-35s %4.5f"% (s_op, t/total_time) failed = True
ety, eva, etr = sys.exc_info()
# print 'Proportion of CPU per op class' n(1)
# dct = {} t(end - start)
# for op, t in zip(self.order, self.times): if failed:
# s_op = str(op).split()[0][1:] raise ety, eva, etr
# dct[s_op] = dct.get(s_op, 0.0) + t
# for t, s_op in reversed(sorted([(t,op) for op, t in dct.items()])): def profile_op(self, f, op):
# if proportion: if self.by_class:
# print " %-35s %4.5f"% (s_op, t/total_time) entry = op.__class__
# else: else:
# print " %-35s %4.5f"% (s_op, t) entry = op
stats = self.stats[entry]
n, t = stats.inc_ncalls, stats.inc_time
failed = False
start = time.time()
try:
f()
end = time.time()
except:
end = time.time()
n, t = stats.inc_nfailures, stats.inc_times_failures
failed = True
exc = sys.exc_info()
if entry not in self.ignore:
n(1)
t(end - start)
if failed:
raise_with_op(op, exc)
def print_stats(self, sort_by = 'time'):
def compare_fn((op1, stat1), (op2, stat2)):
x1 = getattr(stat2, sort_by)
x2 = getattr(stat1, sort_by)
if x1 > x2:
return 1
elif x1 < x2:
return -1
else:
return 0
totals = self.stats['TOTAL']
print 'CPU usage statistics'
print " %-25s %9s %12s %12s %12s" % (("Op%s" % (self.by_class and ' class' or '')), 'NCALLS', 'PER_CALL', 'TOTAL', 'CPU%')
for op, stat in sorted(self.stats.items(), compare_fn):
if op == 'TOTAL': continue
to_print = self.by_class and (op.__module__ + "." + op.__name__) or str(op)
print " %-25s %9i %12.5f %12.5f %12.5f" % (to_print, stat.ncalls, stat.time / stat.ncalls, stat.time, stat.time / totals.time)
stat = self.stats['TOTAL']
print " %-25s %9i %12.5f %12.5f %12.5f" % ('TOTAL (includes overhead)', stat.ncalls, stat.time / stat.ncalls, stat.time, stat.time / totals.time)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论