提交 d34d552d authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

import gof
from core import *
from opt import *
from base_tensor import *
from tensor import *
from compile import *
from grad import *
#from sparse import *
from opt import *
from gradient import *
......@@ -48,6 +48,7 @@ class Function:
features = [],
optimizer = default_optimizer,
linker_cls = gof.link.PerformLinker,
profiler = None,
unpack_single = True,
except_unreachable_input = True,
keep_locals = True):
......@@ -94,13 +95,19 @@ class Function:
# optimize and link the cloned env
if None is not optimizer:
optimizer(env)
linker = linker_cls(env)
if keep_locals:# useful flag for debugging!
self.__dict__.update(locals())
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single)
if profiler is None:
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single)
else:
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single,
profiler=profiler)
def __call__(self, *args):
return self.fn(*args)
......
......@@ -722,12 +722,11 @@ class OpWiseCLinker(Linker):
perform method if no C version can be generated.
"""
def __init__(self, env, profiler = None, fallback_on_perform = True):
def __init__(self, env, fallback_on_perform = True):
self.env = env
self.profiler = profiler
self.fallback_on_perform = fallback_on_perform
def make_thunk(self, inplace = False):
def make_thunk(self, inplace = False, profiler = None):
if inplace:
env = self.env
else:
......@@ -747,7 +746,7 @@ class OpWiseCLinker(Linker):
else:
raise
if self.profiler is None:
if profiler is None:
def f():
try:
for thunk, op in zip(thunks, op_order):
......@@ -755,7 +754,6 @@ class OpWiseCLinker(Linker):
except:
raise_with_op(op)
else:
profiler = self.profiler()
def f():
def g():
for thunk, op in zip(thunks, op_order):
......
......@@ -66,7 +66,7 @@ class Linker:
"""
raise AbstractFunctionError()
def make_function(self, inplace = False, unpack_single = True):
def make_function(self, inplace = False, unpack_single = True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
env used by this Linker and returns values corresponding the the outputs
......@@ -85,7 +85,7 @@ class Linker:
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
"""
thunk, inputs, outputs = self.make_thunk(inplace)
thunk, inputs, outputs = self.make_thunk(inplace, **kwargs)
def execute(*args):
def e_arity(takes, got):
......@@ -101,10 +101,8 @@ class Linker:
else:
return [result.data for result in outputs]
execute.thunk = thunk
try:
execute.profiler = thunk.profiler
except AttributeError:
pass
execute.inputs = inputs
execute.outputs = outputs
return execute
......@@ -117,18 +115,17 @@ class PerformLinker(Linker):
the env in the order given by env.toposort.
"""
def __init__(self, env, profiler = None):
def __init__(self, env):
self.env = env
self.profiler = profiler
def make_thunk(self, inplace = False):
def make_thunk(self, inplace = False, profiler = None):
if inplace:
env = self.env
else:
env = self.env.clone(True)
order = env.toposort()
thunks = [op.perform for op in order]
if self.profiler is None:
if profiler is None:
def f():
try:
for thunk, op in zip(thunks, order):
......@@ -136,7 +133,6 @@ class PerformLinker(Linker):
except:
raise_with_op(op)
else:
profiler = self.profiler()
def f():
def g():
for thunk, op in zip(thunks, order):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论