提交 e6af67eb authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Add type hints on base Linker and stop using undef

Admittedly, the `undef = object()` is a cool trick, but `allow_gc` is simply a null-able boolean. Also, the `make_function` method can probably be removed, so I marked it as deprecated.
上级 b079a36c
...@@ -1881,9 +1881,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1881,9 +1881,9 @@ class OpWiseCLinker(link.LocalLinker):
self.fgraph = None self.fgraph = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors self.nice_errors = nice_errors
self.allow_gc = allow_gc
if schedule: if schedule:
self.schedule = schedule self.schedule = schedule
super().__init__(allow_gc=allow_gc)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" """
......
...@@ -711,9 +711,9 @@ class PerformLinker(LocalLinker): ...@@ -711,9 +711,9 @@ class PerformLinker(LocalLinker):
if allow_gc is None: if allow_gc is None:
allow_gc = theano.config.allow_gc allow_gc = theano.config.allow_gc
self.fgraph = None self.fgraph = None
self.allow_gc = allow_gc
if schedule: if schedule:
self.schedule = schedule self.schedule = schedule
super().__init__(allow_gc=allow_gc)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
""" """
......
...@@ -771,7 +771,6 @@ class VM_Linker(link.LocalLinker): ...@@ -771,7 +771,6 @@ class VM_Linker(link.LocalLinker):
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
self.fgraph = None self.fgraph = None
self.allow_gc = allow_gc
self.use_cloop = use_cloop self.use_cloop = use_cloop
self.callback = callback self.callback = callback
self.callback_input = callback_input self.callback_input = callback_input
...@@ -783,6 +782,7 @@ class VM_Linker(link.LocalLinker): ...@@ -783,6 +782,7 @@ class VM_Linker(link.LocalLinker):
self.updated_vars = {} self.updated_vars = {}
if schedule: if schedule:
self.schedule = schedule self.schedule = schedule
super().__init__(allow_gc=allow_gc)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling=None, profile=None):
"""Check if fgraph is the first FunctionGraph that has ever been """Check if fgraph is the first FunctionGraph that has ever been
......
import typing
from copy import copy from copy import copy
from theano.gof import utils from theano.utils import deprecated
from theano.gof.utils import undef
class Linker: class Linker:
""" """
WRITEME Base type for all linkers.
A linker takes a FunctionGraph and turns it into a callable.
""" """
def clone(self, allow_gc=undef): def __init__(self, *, allow_gc: typing.Optional[bool] = None):
self._allow_gc = allow_gc
super().__init__()
@property
def allow_gc(self) -> typing.Optional[bool]:
"""Determines if the linker may allow garbage collection.
None means undefined.
"""
return self._allow_gc
def clone(self, allow_gc: typing.Optional[bool] = None):
new = copy(self) new = copy(self)
if allow_gc is not undef: if allow_gc is not None:
new.allow_gc = allow_gc new._allow_gc = allow_gc
return new return new
def make_thunk(self): def make_thunk(self):
...@@ -38,9 +51,11 @@ class Linker: ...@@ -38,9 +51,11 @@ class Linker:
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise utils.MethodNotDefined("make_thunk", type(self), self.__class__.__name__) raise NotImplementedError(
f"make_thunk method of {type(self)} is not implemented."
)
# DELETEME # @deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single=True, **kwargs): def make_function(self, unpack_single=True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
...@@ -62,6 +77,8 @@ class Linker: ...@@ -62,6 +77,8 @@ class Linker:
length 1 will be returned. length 1 will be returned.
""" """
from theano.gof import utils
thunk, inputs, outputs = self.make_thunk(**kwargs) thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args): def execute(*args):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论