提交 8bf588af authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Pass custom schedulers to Linker init instead of monkeypatching

Also adds some type hints.
上级 28b2c6f4
...@@ -1749,8 +1749,7 @@ class _Linker(link.LocalLinker): ...@@ -1749,8 +1749,7 @@ class _Linker(link.LocalLinker):
super(gof.LocalLinker, self).__init__() super(gof.LocalLinker, self).__init__()
self.fgraph = None self.fgraph = None
self.maker = maker self.maker = maker
if schedule: super().__init__(scheduler=schedule)
self.schedule = schedule
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
if no_recycling is None: if no_recycling is None:
......
...@@ -591,8 +591,7 @@ class CLinker(link.Linker): ...@@ -591,8 +591,7 @@ class CLinker(link.Linker):
def __init__(self, schedule=None): def __init__(self, schedule=None):
self.fgraph = None self.fgraph = None
if schedule: super().__init__(scheduler=schedule)
self.schedule = schedule
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" """
...@@ -1881,9 +1880,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1881,9 +1880,7 @@ class OpWiseCLinker(link.LocalLinker):
self.fgraph = None self.fgraph = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors self.nice_errors = nice_errors
if schedule: super().__init__(allow_gc=allow_gc, scheduler=schedule)
self.schedule = schedule
super().__init__(allow_gc=allow_gc)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" """
...@@ -2049,8 +2046,7 @@ class DualLinker(link.Linker): ...@@ -2049,8 +2046,7 @@ class DualLinker(link.Linker):
""" """
self.fgraph = None self.fgraph = None
self.checker = checker self.checker = checker
if schedule: super().__init__(scheduler=schedule)
self.schedule = schedule
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" """
......
...@@ -447,6 +447,14 @@ def map_storage(fgraph, order, input_storage, output_storage, storage_map=None): ...@@ -447,6 +447,14 @@ def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
return input_storage, output_storage, storage_map return input_storage, output_storage, storage_map
def add_clear_storage(f, computed, storage_map):
def clear_storage():
for c in computed:
storage_map[c][0] = None
f.clear_storage = clear_storage
def streamline( def streamline(
fgraph, fgraph,
thunks, thunks,
...@@ -579,9 +587,7 @@ class PerformLinker(LocalLinker): ...@@ -579,9 +587,7 @@ class PerformLinker(LocalLinker):
if allow_gc is None: if allow_gc is None:
allow_gc = theano.config.allow_gc allow_gc = theano.config.allow_gc
self.fgraph = None self.fgraph = None
if schedule: super().__init__(allow_gc=allow_gc, scheduler=schedule)
self.schedule = schedule
super().__init__(allow_gc=allow_gc)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" """
...@@ -707,14 +713,6 @@ class PerformLinker(LocalLinker): ...@@ -707,14 +713,6 @@ class PerformLinker(LocalLinker):
) )
def add_clear_storage(f, computed, storage_map):
def clear_storage():
for c in computed:
storage_map[c][0] = None
f.clear_storage = clear_storage
class WrapLinker(Linker): class WrapLinker(Linker):
""" """
This class makes it easier to run several L{LocalLinker}s in parallel, and This class makes it easier to run several L{LocalLinker}s in parallel, and
......
...@@ -780,9 +780,7 @@ class VM_Linker(link.LocalLinker): ...@@ -780,9 +780,7 @@ class VM_Linker(link.LocalLinker):
self.c_thunks = c_thunks self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval self.allow_partial_eval = allow_partial_eval
self.updated_vars = {} self.updated_vars = {}
if schedule: super().__init__(allow_gc=allow_gc, scheduler=schedule)
self.schedule = schedule
super().__init__(allow_gc=allow_gc)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
"""Check if fgraph is the first FunctionGraph that has ever been """Check if fgraph is the first FunctionGraph that has ever been
......
import typing import typing
from copy import copy, deepcopy from copy import copy, deepcopy
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.utils import deprecated from theano.utils import deprecated
...@@ -97,10 +99,10 @@ class Container: ...@@ -97,10 +99,10 @@ class Container:
r = type(self)( r = type(self)(
deepcopy(self.type, memo=memo), deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo), deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo), readonly=deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo), strict=deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo), allow_downcast=deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo), name=deepcopy(self.name, memo=memo),
) )
# Work around NumPy deepcopy of ndarray with 0 dimension that # Work around NumPy deepcopy of ndarray with 0 dimension that
# don't return an ndarray. # don't return an ndarray.
...@@ -120,10 +122,24 @@ class Linker: ...@@ -120,10 +122,24 @@ class Linker:
Base type for all linkers. Base type for all linkers.
A linker takes a FunctionGraph and turns it into a callable. A linker takes a FunctionGraph and turns it into a callable.
Parameters
----------
allow_gc : optional, bool
Configures if garbage collection is enabled.
scheduler : callable
A scheduling function that takes a FunctionGraph and returns a list of Apply nodes.
Defaults to the .toposort() method of the FunctionGraph.
""" """
def __init__(self, *, allow_gc: typing.Optional[bool] = None): def __init__(
self,
*,
allow_gc: typing.Optional[bool] = None,
scheduler: typing.Callable[[FunctionGraph], typing.List[Apply]] = None,
):
self._allow_gc = allow_gc self._allow_gc = allow_gc
self._scheduler = scheduler
super().__init__() super().__init__()
@property @property
...@@ -219,7 +235,21 @@ class Linker: ...@@ -219,7 +235,21 @@ class Linker:
return execute return execute
def schedule(self, fgraph): def schedule(self, fgraph: FunctionGraph) -> typing.List[Apply]:
"""Runs the scheduler (if set) or the toposort on the FunctionGraph.
Parameters
----------
fgraph : FunctionGraph
A graph to compute the schedule for.
Returns
-------
nodes : list of Apply nodes
The result of the scheduling or toposort operation.
"""
if callable(self._scheduler):
return self.scheduler(fgraph)
return fgraph.toposort() return fgraph.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论