提交 28b2c6f4 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Move Container and LocalLinker to link module

The comment said to move Container to compile, but it is directly and exclusively used by the Linkers. Some type hints were added and Container now requires kwarg specification for the optional arguments.
上级 785101c5
......@@ -5,9 +5,10 @@ import numpy as np
import theano
from theano.gof import fg, graph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.link import Container, PerformLinker, WrapLinker
from theano.gof.link import PerformLinker, WrapLinker
from theano.gof.op import Op
from theano.gof.type import Type
from theano.link import Container
from theano.utils import cmp
......
......@@ -18,7 +18,7 @@ from warnings import warn
import numpy as np
import theano
from theano import config, gof
from theano import config, gof, link
from theano.compile.function.types import (
Function,
FunctionMaker,
......@@ -27,7 +27,7 @@ from theano.compile.function.types import (
)
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard
from theano.gof import graph, link, ops_with_inner_function, utils
from theano.gof import graph, ops_with_inner_function, utils
from theano.gof.link import raise_with_op
from theano.utils import get_unbound_function
......@@ -1739,7 +1739,7 @@ class _DummyLinker:
return self
class _Linker(gof.link.LocalLinker):
class _Linker(link.LocalLinker):
"""
Special debugging linker.
......@@ -1793,7 +1793,7 @@ class _Linker(gof.link.LocalLinker):
# the function's outputs will always be freshly allocated.
no_recycling = []
input_storage, output_storage, storage_map = link.map_storage(
input_storage, output_storage, storage_map = theano.gof.link.map_storage(
fgraph, order, input_storage_, output_storage_, storage_map
)
......
......@@ -9,9 +9,9 @@ import logging
import numpy as np
from theano.gof.graph import Variable
from theano.gof.link import Container
from theano.gof.type import generic
from theano.gof.utils import add_tag_trace
from theano.link import Container
_logger = logging.getLogger("theano.compile.sharedvalue")
......
......@@ -5,14 +5,7 @@ from theano.gof.cc import CLinker, DualLinker, HideC, OpWiseCLinker
from theano.gof.destroyhandler import DestroyHandler
from theano.gof.fg import FunctionGraph, InconsistencyError, MissingInputError
from theano.gof.graph import Apply, Constant, Variable, view_roots
from theano.gof.link import (
Container,
Linker,
LocalLinker,
PerformLinker,
WrapLinker,
WrapLinkerMany,
)
from theano.gof.link import PerformLinker, WrapLinker, WrapLinkerMany
from theano.gof.op import (
COp,
Op,
......@@ -54,6 +47,7 @@ from theano.gof.toolbox import (
)
from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic
from theano.gof.utils import MethodNotDefined, hashtype, object2
from theano.link import Container, Linker, LocalLinker
if theano.config.cmodule__preload_cache:
......
......@@ -1214,7 +1214,7 @@ class CLinker(link.Linker):
for input, storage in zip(self.fgraph.inputs, input_storage)
],
[
link.Container(output, storage, True)
link.Container(output, storage, readonly=True)
for output, storage in zip(self.fgraph.outputs, output_storage)
],
error_storage,
......@@ -1987,7 +1987,7 @@ class OpWiseCLinker(link.LocalLinker):
for input, storage in zip(fgraph.inputs, input_storage)
],
[
link.Container(output, storage, True)
link.Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
......
import sys
import traceback
from copy import copy, deepcopy
from copy import copy
from io import StringIO
from sys import getsizeof
from warnings import warn
......@@ -9,8 +9,7 @@ import numpy as np
import theano
from theano.gof import graph, utils
from theano.gof.type import Type
from theano.link.basic import Linker
from theano.link.basic import Container, Linker, LocalLinker
from .utils import undef
......@@ -342,111 +341,6 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
raise exc_value.with_traceback(exc_trace)
# TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container:
"""
This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
Parameters
----------
r : a Variable or a Type
storage
A list of length 1, whose element is the value for `r`.
readonly : bool
True indicates that this should not be setable by Function[r] = val.
strict : bool
If True, we don't allow type casting.
allow_downcast
If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar.
name : str
A string (for pretty-printing?)
"""
def __init__(
self, r, storage, readonly=False, strict=False, allow_downcast=None, name=None
):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
# self.r = r
if isinstance(r, Type):
self.type = r
else:
self.type = r.type
if name is None:
# Some Type do not have a name field.
self.name = getattr(r, "name", None)
else:
self.name = name
self.storage = storage
self.readonly = readonly
self.strict = strict
self.allow_downcast = allow_downcast
def __get__(self):
return self.storage[0]
def __set__(self, value):
if self.readonly:
raise Exception(f"Cannot set readonly storage: {self.name}")
try:
if value is None:
self.storage[0] = None
return
kwargs = {}
if self.strict:
kwargs["strict"] = True
if self.allow_downcast is not None:
kwargs["allow_downcast"] = self.allow_downcast
if hasattr(self.type, "filter_inplace"):
self.storage[0] = self.type.filter_inplace(
value, self.storage[0], **kwargs
)
else:
self.storage[0] = self.type.filter(value, **kwargs)
except Exception as e:
e.args = e.args + (f'Container name "{self.name}"',)
raise
data = property(__get__, __set__)
value = property(__get__, __set__)
def __str__(self):
return "<" + str(self.storage[0]) + ">"
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
data_was_in_memo = id(self.storage[0]) in memo
r = type(self)(
deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
# Work around NumPy deepcopy of ndarray with 0 dimension that
# don't return an ndarray.
if r.storage[0] is not None and not self.type.is_valid_value(r.storage[0]):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
r.storage[0] = self.type.filter(
r.storage[0], strict=False, allow_downcast=False
)
memo[id(self.storage[0])] = r.storage[0]
return r
def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
......@@ -638,32 +532,6 @@ def streamline(
return f
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
"""
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)[:3]
def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise utils.MethodNotDefined("make_all", type(self), self.__class__.__name__)
def gc_helper(node_list):
"""
Return the set of Variable instances which are computed by node_list.
......@@ -831,7 +699,7 @@ class PerformLinker(LocalLinker):
for input, storage in zip(fgraph.inputs, input_storage)
],
[
Container(output, storage, True)
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
......
......@@ -1228,7 +1228,7 @@ class VM_Linker(link.LocalLinker):
for input, storage in zip(fgraph.inputs, input_storage)
],
[
link.Container(output, storage, True)
link.Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
......
from theano.link.basic import Container, Linker, LocalLinker
import typing
from copy import copy
from copy import copy, deepcopy
from theano.gof.type import Type
from theano.utils import deprecated
class Container:
"""
This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
Parameters
----------
r : a Variable or a Type
storage
A list of length 1, whose element is the value for `r`.
readonly : bool
True indicates that this should not be setable by Function[r] = val.
strict : bool
If True, we don't allow type casting.
allow_downcast
If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar.
name : str
A string (for pretty-printing?)
"""
def __init__(
self,
r,
storage,
*,
readonly=False,
strict=False,
allow_downcast=None,
name=None,
):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
if isinstance(r, Type):
self.type = r
else:
self.type = r.type
if name is None:
# Some Type do not have a name field.
self.name = getattr(r, "name", None)
else:
self.name = name
self.storage = storage
self.readonly = readonly
self.strict = strict
self.allow_downcast = allow_downcast
def __get__(self):
return self.storage[0]
def __set__(self, value):
if self.readonly:
raise Exception(f"Cannot set readonly storage: {self.name}")
try:
if value is None:
self.storage[0] = None
return
kwargs = {}
if self.strict:
kwargs["strict"] = True
if self.allow_downcast is not None:
kwargs["allow_downcast"] = self.allow_downcast
if hasattr(self.type, "filter_inplace"):
self.storage[0] = self.type.filter_inplace(
value, self.storage[0], **kwargs
)
else:
self.storage[0] = self.type.filter(value, **kwargs)
except Exception as e:
e.args = e.args + (f'Container name "{self.name}"',)
raise
data = property(__get__, __set__)
value = property(__get__, __set__)
def __str__(self):
return "<" + str(self.storage[0]) + ">"
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
def __deepcopy__(self, memo):
data_was_in_memo = id(self.storage[0]) in memo
r = type(self)(
deepcopy(self.type, memo=memo),
deepcopy(self.storage, memo=memo),
deepcopy(self.readonly, memo=memo),
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
# Work around NumPy deepcopy of ndarray with 0 dimension that
# don't return an ndarray.
if r.storage[0] is not None and not self.type.is_valid_value(r.storage[0]):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
r.storage[0] = self.type.filter(
r.storage[0], strict=False, allow_downcast=False
)
memo[id(self.storage[0])] = r.storage[0]
return r
class Linker:
"""
Base type for all linkers.
......@@ -29,7 +140,13 @@ class Linker:
new._allow_gc = allow_gc
return new
def make_thunk(self):
def make_thunk(
self,
) -> typing.Tuple[
typing.Callable[[], typing.NoReturn],
typing.List[Container],
typing.List[Container],
]:
"""
This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the
......@@ -104,3 +221,31 @@ class Linker:
def schedule(self, fgraph):
return fgraph.toposort()
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
"""
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)[:3]
def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
#
# This function should return a tuple of 5 things
# 1. function to run the program
# 2. input storage
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise NotImplementedError(
f"make_all method of {type(self)} is not implemented."
)
from collections.abc import Sequence
from warnings import warn
from theano.gof import utils
from theano.gof.graph import Constant
from theano.gof.link import (
Container,
PerformLinker,
add_clear_storage,
gc_helper,
map_storage,
streamline,
utils,
)
from theano.link import Container
class JAXLinker(PerformLinker):
......@@ -194,7 +194,7 @@ class JAXLinker(PerformLinker):
for input, storage in zip(fgraph.inputs, input_storage)
],
[
Container(output, storage, True)
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论