提交 2bb3bd57 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed the way Profiler is used

上级 5b24c9e5
import gof import gof
from core import * from base_tensor import *
from opt import * from tensor import *
from compile import * from compile import *
from grad import * #from sparse import *
from opt import *
from gradient import *
...@@ -40,6 +40,7 @@ class Function: ...@@ -40,6 +40,7 @@ class Function:
features = [], features = [],
optimizer = None, optimizer = None,
linker_cls = gof.link.PerformLinker, linker_cls = gof.link.PerformLinker,
profiler = None,
unpack_single = True, unpack_single = True,
except_unreachable_input = True, except_unreachable_input = True,
keep_locals = True): keep_locals = True):
...@@ -86,13 +87,19 @@ class Function: ...@@ -86,13 +87,19 @@ class Function:
# optimize and link the cloned env # optimize and link the cloned env
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
linker = linker_cls(env) linker = linker_cls(env)
if keep_locals:# useful flag for debugging! if keep_locals:# useful flag for debugging!
self.__dict__.update(locals()) self.__dict__.update(locals())
if profiler is None:
self.fn = linker.make_function(inplace=True, self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single) unpack_single=unpack_single)
else:
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single,
profiler=profiler)
def __call__(self, *args): def __call__(self, *args):
return self.fn(*args) return self.fn(*args)
......
...@@ -722,12 +722,11 @@ class OpWiseCLinker(Linker): ...@@ -722,12 +722,11 @@ class OpWiseCLinker(Linker):
perform method if no C version can be generated. perform method if no C version can be generated.
""" """
def __init__(self, env, profiler = None, fallback_on_perform = True): def __init__(self, env, fallback_on_perform = True):
self.env = env self.env = env
self.profiler = profiler
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
def make_thunk(self, inplace = False): def make_thunk(self, inplace = False, profiler = None):
if inplace: if inplace:
env = self.env env = self.env
else: else:
...@@ -747,7 +746,7 @@ class OpWiseCLinker(Linker): ...@@ -747,7 +746,7 @@ class OpWiseCLinker(Linker):
else: else:
raise raise
if self.profiler is None: if profiler is None:
def f(): def f():
try: try:
for thunk, op in zip(thunks, op_order): for thunk, op in zip(thunks, op_order):
...@@ -755,7 +754,6 @@ class OpWiseCLinker(Linker): ...@@ -755,7 +754,6 @@ class OpWiseCLinker(Linker):
except: except:
raise_with_op(op) raise_with_op(op)
else: else:
profiler = self.profiler()
def f(): def f():
def g(): def g():
for thunk, op in zip(thunks, op_order): for thunk, op in zip(thunks, op_order):
......
...@@ -66,7 +66,7 @@ class Linker: ...@@ -66,7 +66,7 @@ class Linker:
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def make_function(self, inplace = False, unpack_single = True): def make_function(self, inplace = False, unpack_single = True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
env used by this Linker and returns values corresponding the the outputs env used by this Linker and returns values corresponding the the outputs
...@@ -85,7 +85,7 @@ class Linker: ...@@ -85,7 +85,7 @@ class Linker:
output, then that output will be returned. Else, a list or tuple of output, then that output will be returned. Else, a list or tuple of
length 1 will be returned. length 1 will be returned.
""" """
thunk, inputs, outputs = self.make_thunk(inplace) thunk, inputs, outputs = self.make_thunk(inplace, **kwargs)
def execute(*args): def execute(*args):
def e_arity(takes, got): def e_arity(takes, got):
...@@ -101,10 +101,8 @@ class Linker: ...@@ -101,10 +101,8 @@ class Linker:
else: else:
return [result.data for result in outputs] return [result.data for result in outputs]
execute.thunk = thunk execute.thunk = thunk
try: execute.inputs = inputs
execute.profiler = thunk.profiler execute.outputs = outputs
except AttributeError:
pass
return execute return execute
...@@ -117,18 +115,17 @@ class PerformLinker(Linker): ...@@ -117,18 +115,17 @@ class PerformLinker(Linker):
the env in the order given by env.toposort. the env in the order given by env.toposort.
""" """
def __init__(self, env, profiler = None): def __init__(self, env):
self.env = env self.env = env
self.profiler = profiler
def make_thunk(self, inplace = False): def make_thunk(self, inplace = False, profiler = None):
if inplace: if inplace:
env = self.env env = self.env
else: else:
env = self.env.clone(True) env = self.env.clone(True)
order = env.toposort() order = env.toposort()
thunks = [op.perform for op in order] thunks = [op.perform for op in order]
if self.profiler is None: if profiler is None:
def f(): def f():
try: try:
for thunk, op in zip(thunks, order): for thunk, op in zip(thunks, order):
...@@ -136,7 +133,6 @@ class PerformLinker(Linker): ...@@ -136,7 +133,6 @@ class PerformLinker(Linker):
except: except:
raise_with_op(op) raise_with_op(op)
else: else:
profiler = self.profiler()
def f(): def f():
def g(): def g():
for thunk, op in zip(thunks, order): for thunk, op in zip(thunks, order):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论