提交 ec08d469 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add schedule to linker interface

上级 520fd3b2
...@@ -1547,10 +1547,11 @@ default_make_thunk = [theano.gof.Op.make_thunk.im_func, ...@@ -1547,10 +1547,11 @@ default_make_thunk = [theano.gof.Op.make_thunk.im_func,
class _Linker(gof.link.LocalLinker): class _Linker(gof.link.LocalLinker):
"""Special debugging linker""" """Special debugging linker"""
def __init__(self, maker): def __init__(self, maker, schedule=None):
super(gof.LocalLinker, self).__init__() super(gof.LocalLinker, self).__init__()
self.fgraph = None self.fgraph = None
self.maker = maker self.maker = maker
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
......
...@@ -402,8 +402,9 @@ class CLinker(link.Linker): ...@@ -402,8 +402,9 @@ class CLinker(link.Linker):
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
def __init__(self): def __init__(self, schedule=None):
self.fgraph = None self.fgraph = None
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
"""WRITEME""" """WRITEME"""
...@@ -1396,13 +1397,15 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1396,13 +1397,15 @@ class OpWiseCLinker(link.LocalLinker):
def __init__(self, def __init__(self,
fallback_on_perform=True, fallback_on_perform=True,
allow_gc=None, allow_gc=None,
nice_errors=True): nice_errors=True,
schedule=None):
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
self.fgraph = None self.fgraph = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors self.nice_errors = nice_errors
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
...@@ -1522,7 +1525,7 @@ class DualLinker(link.Linker): ...@@ -1522,7 +1525,7 @@ class DualLinker(link.Linker):
function. function.
""" """
def __init__(self, checker=_default_checker): def __init__(self, checker=_default_checker, schedule=None):
""" """
Initialize a DualLinker. Initialize a DualLinker.
...@@ -1547,6 +1550,7 @@ class DualLinker(link.Linker): ...@@ -1547,6 +1550,7 @@ class DualLinker(link.Linker):
""" """
self.fgraph = None self.fgraph = None
self.checker = checker self.checker = checker
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
...@@ -1564,11 +1568,13 @@ class DualLinker(link.Linker): ...@@ -1564,11 +1568,13 @@ class DualLinker(link.Linker):
fgraph = self.fgraph fgraph = self.fgraph
no_recycling = self.no_recycling no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = link.PerformLinker().accept(fgraph, _f, i1, o1, thunks1, order1 = (
no_recycling=no_recycling).make_all(**kwargs) link.PerformLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None) kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = OpWiseCLinker().accept(fgraph, _f, i2, o2, thunks2, order2 = (
no_recycling=no_recycling).make_all(**kwargs) OpWiseCLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
def f(): def f():
for input1, input2 in izip(i1, i2): for input1, input2 in izip(i1, i2):
......
...@@ -420,10 +420,11 @@ class PerformLinker(LocalLinker): ...@@ -420,10 +420,11 @@ class PerformLinker(LocalLinker):
the L{FunctionGraph} in the order given by L{Linker.schedule}. the L{FunctionGraph} in the order given by L{Linker.schedule}.
""" """
def __init__(self, allow_gc=True): def __init__(self, allow_gc=True, schedule=None):
#TODO: set allow_gc = True by default, when it works with the OpWiseCLinker #TODO: set allow_gc = True by default, when it works with the OpWiseCLinker
self.fgraph = None self.fgraph = None
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
...@@ -533,7 +534,7 @@ class WrapLinker(Linker): ...@@ -533,7 +534,7 @@ class WrapLinker(Linker):
""" """
def __init__(self, linkers, wrapper): def __init__(self, linkers, wrapper, schedule=None):
""" """
Initialize a WrapLinker. Initialize a WrapLinker.
...@@ -555,6 +556,7 @@ class WrapLinker(Linker): ...@@ -555,6 +556,7 @@ class WrapLinker(Linker):
self.fgraph = None self.fgraph = None
self.linkers = linkers self.linkers = linkers
self.wrapper = wrapper self.wrapper = wrapper
self.schedule = schedule or self.schedule
def __copy__(self): def __copy__(self):
""" """
...@@ -569,7 +571,8 @@ class WrapLinker(Linker): ...@@ -569,7 +571,8 @@ class WrapLinker(Linker):
""" """
other = self.__class__( other = self.__class__(
linkers=[copy(l) for l in self.linkers], linkers=[copy(l) for l in self.linkers],
wrapper=self.wrapper) wrapper=self.wrapper,
schedule=self.schedule)
return other return other
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
......
...@@ -508,7 +508,7 @@ class VM_Linker(link.LocalLinker): ...@@ -508,7 +508,7 @@ class VM_Linker(link.LocalLinker):
""" """
def __init__(self, allow_gc=None, use_cloop=False, callback=None, def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None): lazy=None, schedule=None):
""" """
allow_gc - force the virtual machine to clean up unnecessary allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on references, in order to allow garbage collection on
...@@ -538,6 +538,7 @@ class VM_Linker(link.LocalLinker): ...@@ -538,6 +538,7 @@ class VM_Linker(link.LocalLinker):
self.callback = callback self.callback = callback
self.lazy = lazy self.lazy = lazy
self.updated_vars = {} self.updated_vars = {}
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
...@@ -559,7 +560,8 @@ class VM_Linker(link.LocalLinker): ...@@ -559,7 +560,8 @@ class VM_Linker(link.LocalLinker):
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
use_cloop=self.use_cloop, use_cloop=self.use_cloop,
callback=self.callback, callback=self.callback,
lazy=self.lazy lazy=self.lazy,
schedule=self.schedule
).accept(fgraph, no_recycling) ).accept(fgraph, no_recycling)
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
......
...@@ -15,7 +15,8 @@ class DebugLinker(gof.WrapLinker): ...@@ -15,7 +15,8 @@ class DebugLinker(gof.WrapLinker):
copy_originals=False, copy_originals=False,
check_types=True, check_types=True,
compare_variables=True, compare_variables=True,
compare_fn=(lambda x, y: x == y)): compare_fn=(lambda x, y: x == y),
schedule=None):
if debug_pre is None: if debug_pre is None:
debug_pre = [] debug_pre = []
if debug_post is None: if debug_post is None:
...@@ -46,6 +47,8 @@ class DebugLinker(gof.WrapLinker): ...@@ -46,6 +47,8 @@ class DebugLinker(gof.WrapLinker):
if compare_variables is not None: if compare_variables is not None:
self.debug_post.append(self.compare_variables) self.debug_post.append(self.compare_variables)
self.schedule = schedule or self.schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论