提交 28b2c6f4 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Move Container and LocalLinker to link module

The comment said to move Container to compile, but it is directly and exclusively used by the Linkers. Some type hints were added and Container now requires kwarg specification for the optional arguments.
上级 785101c5
...@@ -5,9 +5,10 @@ import numpy as np ...@@ -5,9 +5,10 @@ import numpy as np
import theano import theano
from theano.gof import fg, graph from theano.gof import fg, graph
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable
from theano.gof.link import Container, PerformLinker, WrapLinker from theano.gof.link import PerformLinker, WrapLinker
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.link import Container
from theano.utils import cmp from theano.utils import cmp
......
...@@ -18,7 +18,7 @@ from warnings import warn ...@@ -18,7 +18,7 @@ from warnings import warn
import numpy as np import numpy as np
import theano import theano
from theano import config, gof from theano import config, gof, link
from theano.compile.function.types import ( from theano.compile.function.types import (
Function, Function,
FunctionMaker, FunctionMaker,
...@@ -27,7 +27,7 @@ from theano.compile.function.types import ( ...@@ -27,7 +27,7 @@ from theano.compile.function.types import (
) )
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard from theano.compile.ops import OutputGuard, _output_guard
from theano.gof import graph, link, ops_with_inner_function, utils from theano.gof import graph, ops_with_inner_function, utils
from theano.gof.link import raise_with_op from theano.gof.link import raise_with_op
from theano.utils import get_unbound_function from theano.utils import get_unbound_function
...@@ -1739,7 +1739,7 @@ class _DummyLinker: ...@@ -1739,7 +1739,7 @@ class _DummyLinker:
return self return self
class _Linker(gof.link.LocalLinker): class _Linker(link.LocalLinker):
""" """
Special debugging linker. Special debugging linker.
...@@ -1793,7 +1793,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1793,7 +1793,7 @@ class _Linker(gof.link.LocalLinker):
# the function's outputs will always be freshly allocated. # the function's outputs will always be freshly allocated.
no_recycling = [] no_recycling = []
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = theano.gof.link.map_storage(
fgraph, order, input_storage_, output_storage_, storage_map fgraph, order, input_storage_, output_storage_, storage_map
) )
......
...@@ -9,9 +9,9 @@ import logging ...@@ -9,9 +9,9 @@ import logging
import numpy as np import numpy as np
from theano.gof.graph import Variable from theano.gof.graph import Variable
from theano.gof.link import Container
from theano.gof.type import generic from theano.gof.type import generic
from theano.gof.utils import add_tag_trace from theano.gof.utils import add_tag_trace
from theano.link import Container
_logger = logging.getLogger("theano.compile.sharedvalue") _logger = logging.getLogger("theano.compile.sharedvalue")
......
...@@ -5,14 +5,7 @@ from theano.gof.cc import CLinker, DualLinker, HideC, OpWiseCLinker ...@@ -5,14 +5,7 @@ from theano.gof.cc import CLinker, DualLinker, HideC, OpWiseCLinker
from theano.gof.destroyhandler import DestroyHandler from theano.gof.destroyhandler import DestroyHandler
from theano.gof.fg import FunctionGraph, InconsistencyError, MissingInputError from theano.gof.fg import FunctionGraph, InconsistencyError, MissingInputError
from theano.gof.graph import Apply, Constant, Variable, view_roots from theano.gof.graph import Apply, Constant, Variable, view_roots
from theano.gof.link import ( from theano.gof.link import PerformLinker, WrapLinker, WrapLinkerMany
Container,
Linker,
LocalLinker,
PerformLinker,
WrapLinker,
WrapLinkerMany,
)
from theano.gof.op import ( from theano.gof.op import (
COp, COp,
Op, Op,
...@@ -54,6 +47,7 @@ from theano.gof.toolbox import ( ...@@ -54,6 +47,7 @@ from theano.gof.toolbox import (
) )
from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic
from theano.gof.utils import MethodNotDefined, hashtype, object2 from theano.gof.utils import MethodNotDefined, hashtype, object2
from theano.link import Container, Linker, LocalLinker
if theano.config.cmodule__preload_cache: if theano.config.cmodule__preload_cache:
......
...@@ -1214,7 +1214,7 @@ class CLinker(link.Linker): ...@@ -1214,7 +1214,7 @@ class CLinker(link.Linker):
for input, storage in zip(self.fgraph.inputs, input_storage) for input, storage in zip(self.fgraph.inputs, input_storage)
], ],
[ [
link.Container(output, storage, True) link.Container(output, storage, readonly=True)
for output, storage in zip(self.fgraph.outputs, output_storage) for output, storage in zip(self.fgraph.outputs, output_storage)
], ],
error_storage, error_storage,
...@@ -1987,7 +1987,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1987,7 +1987,7 @@ class OpWiseCLinker(link.LocalLinker):
for input, storage in zip(fgraph.inputs, input_storage) for input, storage in zip(fgraph.inputs, input_storage)
], ],
[ [
link.Container(output, storage, True) link.Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage) for output, storage in zip(fgraph.outputs, output_storage)
], ],
thunks, thunks,
......
import sys import sys
import traceback import traceback
from copy import copy, deepcopy from copy import copy
from io import StringIO from io import StringIO
from sys import getsizeof from sys import getsizeof
from warnings import warn from warnings import warn
...@@ -9,8 +9,7 @@ import numpy as np ...@@ -9,8 +9,7 @@ import numpy as np
import theano import theano
from theano.gof import graph, utils from theano.gof import graph, utils
from theano.gof.type import Type from theano.link.basic import Container, Linker, LocalLinker
from theano.link.basic import Linker
from .utils import undef from .utils import undef
...@@ -342,111 +341,6 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None): ...@@ -342,111 +341,6 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
# TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container:
"""
This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
Parameters
----------
r : a Variable or a Type
storage
A list of length 1, whose element is the value for `r`.
readonly : bool
True indicates that this should not be setable by Function[r] = val.
strict : bool
If True, we don't allow type casting.
allow_downcast
If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar.
name : str
A string (for pretty-printing?)
"""
def __init__(
self, r, storage, readonly=False, strict=False, allow_downcast=None, name=None
):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
# self.r = r
if isinstance(r, Type):
self.type = r
else:
self.type = r.type
if name is None:
# Some Type do not have a name field.
self.name = getattr(r, "name", None)
else:
self.name = name
self.storage = storage
self.readonly = readonly
self.strict = strict
self.allow_downcast = allow_downcast
def __get__(self):
return self.storage[0]
def __set__(self, value):
if self.readonly:
raise Exception(f"Cannot set readonly storage: {self.name}")
try:
if value is None:
self.storage[0] = None
return
kwargs = {}
if self.strict:
kwargs["strict"] = True
if self.allow_downcast is not None:
kwargs["allow_downcast"] = self.allow_downcast
if hasattr(self.type, "filter_inplace"):
self.storage[0] = self.type.filter_inplace(
value, self.storage[0], **kwargs
)
else:
self.storage[0] = self.type.filter(value, **kwargs)
except Exception as e:
e.args = e.args + (f'Container name "{self.name}"',)
raise
data = property(__get__, __set__)
value = property(__get__, __set__)
def __str__(self):
return "<" + str(self.storage[0]) + ">"
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
data_was_in_memo = id(self.storage[0]) in memo
r = type(self)(
deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
# Work around NumPy deepcopy of ndarray with 0 dimension that
# don't return an ndarray.
if r.storage[0] is not None and not self.type.is_valid_value(r.storage[0]):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
r.storage[0] = self.type.filter(
r.storage[0], strict=False, allow_downcast=False
)
memo[id(self.storage[0])] = r.storage[0]
return r
def map_storage(fgraph, order, input_storage, output_storage, storage_map=None): def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes. """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
...@@ -638,32 +532,6 @@ def streamline( ...@@ -638,32 +532,6 @@ def streamline(
return f return f
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
"""
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)[:3]
def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise utils.MethodNotDefined("make_all", type(self), self.__class__.__name__)
def gc_helper(node_list): def gc_helper(node_list):
""" """
Return the set of Variable instances which are computed by node_list. Return the set of Variable instances which are computed by node_list.
...@@ -831,7 +699,7 @@ class PerformLinker(LocalLinker): ...@@ -831,7 +699,7 @@ class PerformLinker(LocalLinker):
for input, storage in zip(fgraph.inputs, input_storage) for input, storage in zip(fgraph.inputs, input_storage)
], ],
[ [
Container(output, storage, True) Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage) for output, storage in zip(fgraph.outputs, output_storage)
], ],
thunks, thunks,
......
...@@ -1228,7 +1228,7 @@ class VM_Linker(link.LocalLinker): ...@@ -1228,7 +1228,7 @@ class VM_Linker(link.LocalLinker):
for input, storage in zip(fgraph.inputs, input_storage) for input, storage in zip(fgraph.inputs, input_storage)
], ],
[ [
link.Container(output, storage, True) link.Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage) for output, storage in zip(fgraph.outputs, output_storage)
], ],
thunks, thunks,
......
from theano.link.basic import Container, Linker, LocalLinker
import typing import typing
from copy import copy from copy import copy, deepcopy
from theano.gof.type import Type
from theano.utils import deprecated from theano.utils import deprecated
class Container:
"""
This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
Parameters
----------
r : a Variable or a Type
storage
A list of length 1, whose element is the value for `r`.
readonly : bool
True indicates that this should not be setable by Function[r] = val.
strict : bool
If True, we don't allow type casting.
allow_downcast
If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar.
name : str
A string (for pretty-printing?)
"""
def __init__(
self,
r,
storage,
*,
readonly=False,
strict=False,
allow_downcast=None,
name=None,
):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
if isinstance(r, Type):
self.type = r
else:
self.type = r.type
if name is None:
# Some Type do not have a name field.
self.name = getattr(r, "name", None)
else:
self.name = name
self.storage = storage
self.readonly = readonly
self.strict = strict
self.allow_downcast = allow_downcast
def __get__(self):
return self.storage[0]
def __set__(self, value):
if self.readonly:
raise Exception(f"Cannot set readonly storage: {self.name}")
try:
if value is None:
self.storage[0] = None
return
kwargs = {}
if self.strict:
kwargs["strict"] = True
if self.allow_downcast is not None:
kwargs["allow_downcast"] = self.allow_downcast
if hasattr(self.type, "filter_inplace"):
self.storage[0] = self.type.filter_inplace(
value, self.storage[0], **kwargs
)
else:
self.storage[0] = self.type.filter(value, **kwargs)
except Exception as e:
e.args = e.args + (f'Container name "{self.name}"',)
raise
data = property(__get__, __set__)
value = property(__get__, __set__)
def __str__(self):
return "<" + str(self.storage[0]) + ">"
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
data_was_in_memo = id(self.storage[0]) in memo
r = type(self)(
deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
# Work around NumPy deepcopy of ndarray with 0 dimension that
# don't return an ndarray.
if r.storage[0] is not None and not self.type.is_valid_value(r.storage[0]):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
r.storage[0] = self.type.filter(
r.storage[0], strict=False, allow_downcast=False
)
memo[id(self.storage[0])] = r.storage[0]
return r
class Linker: class Linker:
""" """
Base type for all linkers. Base type for all linkers.
...@@ -29,7 +140,13 @@ class Linker: ...@@ -29,7 +140,13 @@ class Linker:
new._allow_gc = allow_gc new._allow_gc = allow_gc
return new return new
def make_thunk(self): def make_thunk(
self,
) -> typing.Tuple[
typing.Callable[[], typing.NoReturn],
typing.List[Container],
typing.List[Container],
]:
""" """
This function must return a triplet (function, input_variables, This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the output_variables) where function is a thunk that operates on the
...@@ -104,3 +221,31 @@ class Linker: ...@@ -104,3 +221,31 @@ class Linker:
def schedule(self, fgraph): def schedule(self, fgraph):
return fgraph.toposort() return fgraph.toposort()
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
"""
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)[:3]
def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise NotImplementedError(
f"make_all method of {type(self)} is not implemented."
)
from collections.abc import Sequence from collections.abc import Sequence
from warnings import warn from warnings import warn
from theano.gof import utils
from theano.gof.graph import Constant from theano.gof.graph import Constant
from theano.gof.link import ( from theano.gof.link import (
Container,
PerformLinker, PerformLinker,
add_clear_storage, add_clear_storage,
gc_helper, gc_helper,
map_storage, map_storage,
streamline, streamline,
utils,
) )
from theano.link import Container
class JAXLinker(PerformLinker): class JAXLinker(PerformLinker):
...@@ -194,7 +194,7 @@ class JAXLinker(PerformLinker): ...@@ -194,7 +194,7 @@ class JAXLinker(PerformLinker):
for input, storage in zip(fgraph.inputs, input_storage) for input, storage in zip(fgraph.inputs, input_storage)
], ],
[ [
Container(output, storage, True) Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage) for output, storage in zip(fgraph.outputs, output_storage)
], ],
thunks, thunks,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论