提交 5e07df51 authored 作者: James Bergstra's avatar James Bergstra

added ProfileMode to sandbox/wraplinker

上级 02818b45
from __future__ import absolute_import from __future__ import absolute_import
import time
import numpy import numpy
from ..gof.link import WrapLinker from ..gof.link import WrapLinker
from ..gradient import numeric_grad from ..compile.mode import Mode
class Todo(Exception): """todo""" class Todo(Exception): """todo"""
#WrapLinker wrappers #WrapLinker wrappers
def cmp_outputs(i, node, *thunks): if 0:
"""WrapLinker wrapper: raise an exception if outputs are different from ..gradient import numeric_grad
def cmp_outputs(i, node, *thunks):
"""WrapLinker wrapper: raise an exception if outputs are different
numpy.ndarrays of floating point types are compared approximately, rather numpy.ndarrays of floating point types are compared approximately, rather
than exactly. than exactly.
""" """
class MisMatch(Exception): """Output mismatch""" class MisMatch(Exception): """Output mismatch"""
#define a comparison function, which works for all the results in a graph #define a comparison function, which works for all the results in a graph
#TODO: consider factoring this out (and maybe passing args explicitly #TODO: consider factoring this out (and maybe passing args explicitly
# instead of by closure) # instead of by closure)
def my_check_equal(x, y): def my_check_equal(x, y):
if type(x) != type(y): if type(x) != type(y):
raise MisMatch("Output type mismatch", (x, y)) raise MisMatch("Output type mismatch", (x, y))
if hasattr(x, 'dtype'): if hasattr(x, 'dtype'):
# was: isinstance(x,numpy.ndarray), which doesn't # was: isinstance(x,numpy.ndarray), which doesn't
# catch numpy.float64 # catch numpy.float64
if x.dtype != y.dtype or x.shape != y.shape: if x.dtype != y.dtype or x.shape != y.shape:
raise MisMatch("ndarray type/shape.", (x,y)) raise MisMatch("ndarray type/shape.", (x,y))
if str(x.dtype).startswith('float'): if str(x.dtype).startswith('float'):
assert str(x.dtype) == 'float64' #otherwise we need to adjust assert str(x.dtype) == 'float64' #otherwise we need to adjust
#our constant below... but to what? #our constant below... but to what?
abs_rel_err = numeric_grad.abs_rel_err(x, y) abs_rel_err = numeric_grad.abs_rel_err(x, y)
max_abs_rel_err = numpy.max(abs_rel_err) max_abs_rel_err = numpy.max(abs_rel_err)
if max_abs_rel_err > 1.0e-7: if max_abs_rel_err > 1.0e-7:
raise MisMatch('max_abs_rel_err exceeds tolerence', (max_abs_rel_err, raise MisMatch('max_abs_rel_err exceeds tolerence', (max_abs_rel_err,
x, y)) x, y))
elif str(x.dtype).startswith('complex'): elif str(x.dtype).startswith('complex'):
raise Todo() raise Todo()
else:
if not numpy.all(x==y):
raise MisMatch
else: else:
if not numpy.all(x==y): print 'wtf??', type(x), type(y), node.op
raise MisMatch if x != y:
print 'wow!! wtf??'
else: raise MisMatch("Output mismatch.", (x, y))
print 'wtf??', type(x), type(y), node.op
if x != y: #loop over all the thunks
print 'wow!! wtf??' # ensure that the outputs from the first thunk match the outputs from
raise MisMatch("Output mismatch.", (x, y)) # all subsequent thunks
n_thunks = len(thunks)
#loop over all the thunks if n_thunks > 1:
# ensure that the outputs from the first thunk match the outputs from th0 = thunks[0]
# all subsequent thunks for th in thunks[1:]:
n_thunks = len(thunks) for out0, outN in zip(th0.outputs, th.outputs):
if n_thunks > 1: my_check_equal(out0[0], outN[0])
th0 = thunks[0]
for th in thunks[1:]:
for out0, outN in zip(th0.outputs, th.outputs):
my_check_equal(out0[0], outN[0])
#TODO: better name for 'f' #TODO: better name for 'f'
def numpy_wrapper(f): def numpy_wrapper(f):
...@@ -98,3 +101,51 @@ def WrapLinkerMany(linkers, wrappers): ...@@ -98,3 +101,51 @@ def WrapLinkerMany(linkers, wrappers):
def DualLinker(linkers): def DualLinker(linkers):
return WrapLinkerMany(linkers, [run_all, cmp_outputs]) return WrapLinkerMany(linkers, [run_all, cmp_outputs])
class ProfileMode(Mode):
def __init__(self, local_linker, optimizer=None):
local_time = [0.0]
apply_time = {}
op_time = {}
def blah(i, node, *thunks):
t0 = time.time()
for th in thunks:
th()
dt = time.time() - t0
local_time[0] += dt
apply_time[(i,node.op)] = apply_time.get((i,node.op), 0.0) + dt
op_time[node.op] = op_time.get(node.op, 0.0) + dt
self.local_time = local_time
self.apply_time = apply_time
self.op_time = op_time
linker = WrapLinkerMany([local_linker], [blah])
if optimizer:
Mode.__init__(self, linker, optimizer)
else:
Mode.__init__(self, linker)
def print_summary(self):
local_time = self.local_time[0]
apply_time = self.apply_time
op_time = self.op_time
print 'local_time', local_time
print 'apply-wise times'
atimes = [(t/local_time, (a[0], str(a[1]))) for a, t in apply_time.items()]
atimes.sort()
atimes.reverse()
for t,a in atimes[:15]:
print ' ', t, a
print ' ...' #show that we are ignoring applies that don't take much time
print 'op-wise times'
otimes = [(t/local_time, a) for a, t in op_time.items()]
otimes.sort()
otimes.reverse()
for t,a in otimes[:15]:
print ' ', t, a
print ' ...' #show that we are ignoring applies that don't take much time
print sum(t for a,t in op_time.items())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论