提交 ec08d469 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add schedule to linker interface

上级 520fd3b2
......@@ -1547,10 +1547,11 @@ default_make_thunk = [theano.gof.Op.make_thunk.im_func,
class _Linker(gof.link.LocalLinker):
"""Special debugging linker"""
def __init__(self, maker):
def __init__(self, maker, schedule=None):
super(gof.LocalLinker, self).__init__()
self.fgraph = None
self.maker = maker
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
......
......@@ -402,8 +402,9 @@ class CLinker(link.Linker):
associated to it during the computation (to avoid reusing it).
"""
def __init__(self):
def __init__(self, schedule=None):
self.fgraph = None
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
"""WRITEME"""
......@@ -1396,13 +1397,15 @@ class OpWiseCLinker(link.LocalLinker):
def __init__(self,
fallback_on_perform=True,
allow_gc=None,
nice_errors=True):
nice_errors=True,
schedule=None):
if allow_gc is None:
allow_gc = config.allow_gc
self.fgraph = None
self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors
self.allow_gc = allow_gc
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
......@@ -1522,7 +1525,7 @@ class DualLinker(link.Linker):
function.
"""
def __init__(self, checker=_default_checker):
def __init__(self, checker=_default_checker, schedule=None):
"""
Initialize a DualLinker.
......@@ -1547,6 +1550,7 @@ class DualLinker(link.Linker):
"""
self.fgraph = None
self.checker = checker
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
......@@ -1564,11 +1568,13 @@ class DualLinker(link.Linker):
fgraph = self.fgraph
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = link.PerformLinker().accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs)
_f, i1, o1, thunks1, order1 = (
link.PerformLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = OpWiseCLinker().accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = (
OpWiseCLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
def f():
for input1, input2 in izip(i1, i2):
......
......@@ -420,10 +420,11 @@ class PerformLinker(LocalLinker):
the L{FunctionGraph} in the order given by L{Linker.schedule}.
"""
def __init__(self, allow_gc=True):
def __init__(self, allow_gc=True, schedule=None):
#TODO: set allow_gc = True by default, when it works with the OpWiseCLinker
self.fgraph = None
self.allow_gc = allow_gc
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
"""
......@@ -533,7 +534,7 @@ class WrapLinker(Linker):
"""
def __init__(self, linkers, wrapper):
def __init__(self, linkers, wrapper, schedule=None):
"""
Initialize a WrapLinker.
......@@ -555,6 +556,7 @@ class WrapLinker(Linker):
self.fgraph = None
self.linkers = linkers
self.wrapper = wrapper
self.schedule = schedule or self.schedule
def __copy__(self):
"""
......@@ -569,7 +571,8 @@ class WrapLinker(Linker):
"""
other = self.__class__(
linkers=[copy(l) for l in self.linkers],
wrapper=self.wrapper)
wrapper=self.wrapper,
schedule=self.schedule)
return other
def accept(self, fgraph, no_recycling=None):
......
......@@ -508,7 +508,7 @@ class VM_Linker(link.LocalLinker):
"""
def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None):
lazy=None, schedule=None):
"""
allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on
......@@ -538,6 +538,7 @@ class VM_Linker(link.LocalLinker):
self.callback = callback
self.lazy = lazy
self.updated_vars = {}
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
"""
......@@ -559,7 +560,8 @@ class VM_Linker(link.LocalLinker):
allow_gc=self.allow_gc,
use_cloop=self.use_cloop,
callback=self.callback,
lazy=self.lazy
lazy=self.lazy,
schedule=self.schedule
).accept(fgraph, no_recycling)
self.fgraph = fgraph
self.no_recycling = no_recycling
......
......@@ -15,7 +15,8 @@ class DebugLinker(gof.WrapLinker):
copy_originals=False,
check_types=True,
compare_variables=True,
compare_fn=(lambda x, y: x == y)):
compare_fn=(lambda x, y: x == y),
schedule=None):
if debug_pre is None:
debug_pre = []
if debug_post is None:
......@@ -46,6 +47,8 @@ class DebugLinker(gof.WrapLinker):
if compare_variables is not None:
self.debug_post.append(self.compare_variables)
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
no_recycling = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论