提交 8bf588af authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Pass custom schedulers to Linker init instead of monkeypatching

Also adds some type hints.
上级 28b2c6f4
......@@ -1749,8 +1749,7 @@ class _Linker(link.LocalLinker):
super(gof.LocalLinker, self).__init__()
self.fgraph = None
self.maker = maker
if schedule:
self.schedule = schedule
super().__init__(scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
if no_recycling is None:
......
......@@ -591,8 +591,7 @@ class CLinker(link.Linker):
def __init__(self, schedule=None):
self.fgraph = None
if schedule:
self.schedule = schedule
super().__init__(scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
"""
......@@ -1881,9 +1880,7 @@ class OpWiseCLinker(link.LocalLinker):
self.fgraph = None
self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors
if schedule:
self.schedule = schedule
super().__init__(allow_gc=allow_gc)
super().__init__(allow_gc=allow_gc, scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
"""
......@@ -2049,8 +2046,7 @@ class DualLinker(link.Linker):
"""
self.fgraph = None
self.checker = checker
if schedule:
self.schedule = schedule
super().__init__(scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
"""
......
......@@ -447,6 +447,14 @@ def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
return input_storage, output_storage, storage_map
def add_clear_storage(f, computed, storage_map):
def clear_storage():
for c in computed:
storage_map[c][0] = None
f.clear_storage = clear_storage
def streamline(
fgraph,
thunks,
......@@ -579,9 +587,7 @@ class PerformLinker(LocalLinker):
if allow_gc is None:
allow_gc = theano.config.allow_gc
self.fgraph = None
if schedule:
self.schedule = schedule
super().__init__(allow_gc=allow_gc)
super().__init__(allow_gc=allow_gc, scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
"""
......@@ -707,14 +713,6 @@ class PerformLinker(LocalLinker):
)
def add_clear_storage(f, computed, storage_map):
def clear_storage():
for c in computed:
storage_map[c][0] = None
f.clear_storage = clear_storage
class WrapLinker(Linker):
"""
This class makes it easier to run several L{LocalLinker}s in parallel, and
......
......@@ -780,9 +780,7 @@ class VM_Linker(link.LocalLinker):
self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval
self.updated_vars = {}
if schedule:
self.schedule = schedule
super().__init__(allow_gc=allow_gc)
super().__init__(allow_gc=allow_gc, scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
"""Check if fgraph is the first FunctionGraph that has ever been
......
import typing
from copy import copy, deepcopy
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply
from theano.gof.type import Type
from theano.utils import deprecated
......@@ -97,10 +99,10 @@ class Container:
r = type(self)(
deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
readonly=deepcopy(self.readonly, memo=memo),
strict=deepcopy(self.strict, memo=memo),
allow_downcast=deepcopy(self.allow_downcast, memo=memo),
name=deepcopy(self.name, memo=memo),
)
# Work around NumPy deepcopy of ndarray with 0 dimension that
# don't return an ndarray.
......@@ -120,10 +122,24 @@ class Linker:
Base type for all linkers.
A linker takes a FunctionGraph and turns it into a callable.
Parameters
----------
allow_gc : optional, bool
Configures if garbage collection is enabled.
scheduler : callable
A scheduling function that takes a FunctionGraph and returns a list of Apply nodes.
Defaults to the .toposort() method of the FunctionGraph.
"""
def __init__(self, *, allow_gc: typing.Optional[bool] = None):
def __init__(
self,
*,
allow_gc: typing.Optional[bool] = None,
scheduler: typing.Callable[[FunctionGraph], typing.List[Apply]] = None,
):
self._allow_gc = allow_gc
self._scheduler = scheduler
super().__init__()
@property
......@@ -219,7 +235,21 @@ class Linker:
return execute
def schedule(self, fgraph):
def schedule(self, fgraph: FunctionGraph) -> typing.List[Apply]:
"""Runs the scheduler (if set) or the toposort on the FunctionGraph.
Parameters
----------
fgraph : FunctionGraph
A graph to compute the schedule for.
Returns
-------
nodes : list of Apply nodes
The result of the scheduling or toposort operation.
"""
if callable(self._scheduler):
return self.scheduler(fgraph)
return fgraph.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论