提交 5bb459cf authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Migrate all of theano.gof.link to theano.link sub-package

上级 8bf588af
......@@ -11,7 +11,7 @@ purpose of it is to hack it to investigate what your own particular program is d
.. code-block:: python
from theano.gof.link import WrapLinkerMany
from theano.link import WrapLinkerMany
from theano import config
from theano.compile.mode import (Mode, register_mode, predefined_modes, predefined_linkers,
predefined_optimizers)
......
......@@ -5,9 +5,9 @@ import theano
from theano.gof import fg
from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.link import PerformLinker
from theano.gof.op import Op
from theano.gof.type import Type
from theano.link import PerformLinker
def as_variable(x):
......
......@@ -5,10 +5,9 @@ import numpy as np
import theano
from theano.gof import fg, graph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.link import PerformLinker, WrapLinker
from theano.gof.op import Op
from theano.gof.type import Type
from theano.link import Container
from theano.link import Container, PerformLinker, WrapLinker
from theano.utils import cmp
......
......@@ -6,6 +6,7 @@ from tests import unittest_tools as utt
from theano import config, function, scalar
from theano.gof import FunctionGraph
from theano.gof.opt import out2in
from theano.link.basic import PerformLinker
# from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -158,7 +159,7 @@ def test_local_dimshuffle_alloc():
g = FunctionGraph([x], [out])
reshape_dimshuffle(g)
l = theano.gof.PerformLinker()
l = PerformLinker()
l.accept(g)
f = l.make_function()
......
......@@ -28,7 +28,8 @@ from theano.compile.function.types import (
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard
from theano.gof import graph, ops_with_inner_function, utils
from theano.gof.link import raise_with_op
from theano.link.basic import LocalLinker
from theano.link.debugging import raise_with_op
from theano.utils import get_unbound_function
......@@ -1739,14 +1740,14 @@ class _DummyLinker:
return self
class _Linker(link.LocalLinker):
class _Linker(LocalLinker):
"""
Special debugging linker.
"""
def __init__(self, maker, schedule=None):
super(gof.LocalLinker, self).__init__()
super().__init__()
self.fgraph = None
self.maker = maker
super().__init__(scheduler=schedule)
......@@ -1792,7 +1793,7 @@ class _Linker(link.LocalLinker):
# the function's outputs will always be freshly allocated.
no_recycling = []
input_storage, output_storage, storage_map = theano.gof.link.map_storage(
input_storage, output_storage, storage_map = theano.link.map_storage(
fgraph, order, input_storage_, output_storage_, storage_map
)
......
......@@ -15,7 +15,7 @@ import numpy as np
import theano
import theano.compile.profiling
from theano import config, gof
from theano import config, gof, link
from theano.compile.io import In, SymbolicInput, SymbolicOutput
from theano.compile.ops import deep_copy_op, view_op
from theano.gof import graph
......@@ -978,7 +978,7 @@ class Function:
thunk = None
if hasattr(self.fn, "thunks"):
thunk = self.fn.thunks[self.fn.position_of_error]
gof.link.raise_with_op(
link.raise_with_op(
self.maker.fgraph,
node=self.fn.nodes[self.fn.position_of_error],
thunk=thunk,
......
......@@ -5,7 +5,6 @@ from theano.gof.cc import CLinker, DualLinker, HideC, OpWiseCLinker
from theano.gof.destroyhandler import DestroyHandler
from theano.gof.fg import FunctionGraph, InconsistencyError, MissingInputError
from theano.gof.graph import Apply, Constant, Variable, view_roots
from theano.gof.link import PerformLinker, WrapLinker, WrapLinkerMany
from theano.gof.op import (
COp,
Op,
......@@ -47,7 +46,14 @@ from theano.gof.toolbox import (
)
from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic
from theano.gof.utils import MethodNotDefined, hashtype, object2
from theano.link import Container, Linker, LocalLinker
from theano.link import (
Container,
Linker,
LocalLinker,
PerformLinker,
WrapLinker,
WrapLinkerMany,
)
if theano.config.cmodule__preload_cache:
......
......@@ -11,8 +11,8 @@ from io import StringIO
import numpy as np
from theano import config
from theano.gof import cmodule, graph, link, utils
from theano import config, link
from theano.gof import cmodule, graph, utils
from theano.gof.callcache import CallCache
from theano.gof.compilelock import get_lock, release_lock
......
......@@ -13,9 +13,7 @@ import warnings
from collections import defaultdict
import theano.gof.cmodule
from theano import config
from . import link
from theano import config, link
logger = logging.getLogger(__name__)
......
from theano.link.basic import Container, Linker, LocalLinker
import sys
from theano.link.basic import (
Container,
Linker,
LocalLinker,
PerformLinker,
WrapLinker,
WrapLinkerMany,
gc_helper,
map_storage,
streamline,
)
from theano.link.debugging import raise_with_op, set_excepthook
set_excepthook(handler=sys.stdout)
import typing
from copy import copy, deepcopy
from theano import config, utils
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply
from theano.gof.graph import Apply, Constant
from theano.gof.type import Type
from theano.utils import deprecated
from theano.gof.utils import to_return_values
from theano.link.debugging import raise_with_op
class Container:
......@@ -188,7 +190,7 @@ class Linker:
f"make_thunk method of {type(self)} is not implemented."
)
@deprecated("Marked for deletion. Only tests use it.")
@utils.deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
......@@ -210,8 +212,6 @@ class Linker:
length 1 will be returned.
"""
from theano.gof import utils
thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args):
......@@ -225,7 +225,7 @@ class Linker:
variable.data = arg
thunk()
if unpack_single:
return utils.to_return_values([variable.data for variable in outputs])
return to_return_values([variable.data for variable in outputs])
else:
return [variable.data for variable in outputs]
......@@ -249,7 +249,7 @@ class Linker:
The result of the scheduling or toposort operation.
"""
if callable(self._scheduler):
return self.scheduler(fgraph)
return self._scheduler(fgraph)
return fgraph.toposort()
......@@ -279,3 +279,535 @@ class LocalLinker(Linker):
raise NotImplementedError(
f"make_all method of {type(self)} is not implemented."
)
def map_storage(
    fgraph: FunctionGraph,
    order: typing.Iterable[Apply],
    input_storage: typing.Optional[typing.List],
    output_storage: typing.Optional[typing.List],
    storage_map: typing.Optional[typing.Dict] = None,
) -> typing.Tuple[typing.List, typing.List, typing.Dict]:
    """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.

    This function iterates over the nodes in `order` and ensures that for every
    input and output `Variable`, there is a unique storage container. This is
    returned as a dictionary Variable -> storage called the `storage_map`.

    Parameters
    ----------
    fgraph
        The current fgraph. This function uses the inputs and outputs
        attributes.
    order
        An iterable over Apply instances (in program running order).
    input_storage
        None or existing input storage (a list parallel to ``fgraph.inputs``).
    output_storage
        None or existing output storage (a list parallel to ``fgraph.outputs``).
    storage_map
        None, or an existing Variable -> storage mapping to extend in place.

    Returns
    -------
    3-tuple
        List of storage for inputs (parallel to ``fgraph.inputs``), list of
        storage for outputs (parallel to ``fgraph.outputs``), and the
        `storage_map`.
    """
    # Each Apply argument's data is stored in a list of length 1; these
    # single-element lists act like pointers shared between thunks.
    if storage_map is None:
        storage_map = {}

    # input_storage is a list of data-containers for the inputs.
    if input_storage is None:
        input_storage = [[None] for _ in fgraph.inputs]
    else:
        assert len(fgraph.inputs) == len(input_storage)

    # Record the input cells in storage_map; a pre-existing entry must be the
    # exact same cell object, otherwise two "pointers" would diverge.
    for r, storage in zip(fgraph.inputs, input_storage):
        if r in storage_map:
            assert storage_map[r] is storage, (
                "Given input_storage conflicts with storage in given storage_map. Given input_storage: ",
                storage,
                "Storage in storage_map: ",
                storage_map[r],
            )
        else:
            storage_map[r] = storage

    # for orphan in fgraph.orphans:
    #    if not isinstance(orphan, Constant):
    #        raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
    #    storage_map[orphan] = [orphan.data]

    # Allocate output storage, applying the same identity check as above.
    if output_storage is not None:
        assert len(fgraph.outputs) == len(output_storage)
        for r, storage in zip(fgraph.outputs, output_storage):
            if r in storage_map:
                assert storage_map[r] is storage, (
                    "Given output_storage conflicts with storage in given storage_map. Given output_storage: ",
                    storage,
                    "Storage in storage_map: ",
                    storage_map[r],
                )
            else:
                storage_map[r] = storage

    # Allocate storage for intermediate computation.
    for node in order:
        for r in node.inputs:
            if r not in storage_map:
                # An input with no storage yet must be a constant; its cell is
                # seeded with the constant's value.
                assert isinstance(r, Constant)
                storage_map[r] = [r.data]
        for r in node.outputs:
            storage_map.setdefault(r, [None])
    for r in fgraph.outputs:
        if isinstance(r, Constant):
            storage_map.setdefault(r, [r.data])

    # Extract output storage when the caller did not supply it.
    if output_storage is None:
        output_storage = [storage_map[r] for r in fgraph.outputs]

    return input_storage, output_storage, storage_map
def add_clear_storage(f, computed, storage_map):
    """Attach a ``clear_storage`` callable to *f*.

    Calling ``f.clear_storage()`` resets the storage cell of every variable in
    *computed*, dropping references so intermediate results can be collected.
    """

    def clear_storage():
        for var in computed:
            storage_map[var][0] = None

    f.clear_storage = clear_storage
def streamline(
    fgraph: FunctionGraph,
    thunks,
    order,
    post_thunk_old_storage=None,
    no_recycling=None,
    nice_errors=True,
) -> typing.Callable[[], None]:
    """Build a single callable that runs `thunks` in sequence.

    Parameters
    ----------
    fgraph
        The FunctionGraph the thunks belong to (used for error annotation).
    thunks
        The list of program instructions.
    order
        The list of apply instances that gave rise to the thunks
        (same order as thunks).
    post_thunk_old_storage
        A list (corresponding to thunks, order) whose elements are lists of
        storage cells, that should be cleared after running the corresponding
        thunk. A value of None disables this functionality.
    no_recycling
        Storage elements that cannot be 'recycled' by repeatedly executing the
        program. These storage elements are cleared before re-running.
    nice_errors
        Run in such a way that the double-traceback is printed. This costs a
        bit of performance in the inner python loop.

    Returns
    -------
    callable
        A zero-argument function that runs the whole program once and
        returns None.
    """
    if no_recycling is None:
        no_recycling = []

    if len(thunks) != len(order):
        raise ValueError(
            "Length of thunks and order must match", (len(thunks), len(order))
        )

    if post_thunk_old_storage:
        if len(thunks) != len(post_thunk_old_storage):
            raise ValueError(
                "Length of thunks and post_thunk_old_storage must match",
                (len(thunks), len(post_thunk_old_storage)),
            )

        def streamline_default_f():
            # Clear storage that must not survive between calls.
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node, old_storage in zip(
                    thunks, order, post_thunk_old_storage
                ):
                    thunk()
                    # Free inputs that will never be read again (gc).
                    for old_s in old_storage:
                        old_s[0] = None
            except Exception:
                raise_with_op(fgraph, node, thunk)

        f = streamline_default_f
    elif nice_errors:

        def streamline_nice_errors_f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node in zip(thunks, order):
                    thunk()
            except Exception:
                raise_with_op(fgraph, node, thunk)

        f = streamline_nice_errors_f
    else:
        # don't worry about raise_with_op, just go a little faster.
        # there is a mix of python and c thunks
        def streamline_fast_f():
            for x in no_recycling:
                x[0] = None
            for thunk in thunks:
                thunk()

        f = streamline_fast_f
    return f
def gc_helper(node_list: typing.List[Apply]):
    """
    Return the set of Variable instances which are computed by node_list.

    Parameters
    ----------
    node_list
        List of Apply instances in program execution order.

    Returns
    -------
    2-tuple
        FIRST, the set of Variable instances which are computed by node_list,
        and SECOND a dictionary that maps each Variable instance to the last
        node to use Variable as an input.

    Extended Summary
    ----------------
    This is used to allow garbage collection within graphs.

    It ignores view_map and destroy_map. This isn't needed as python
    have reference count. In Theano gc, we should not take into
    account view_map and destroy_map as if the thunk decided to create
    a new output, we would delay uselessly its gc by Python.
    """
    # for freeing memory
    last_user = {}
    computed = set()
    for apply_node in node_list:
        # Everything an Apply produces counts as "computed".
        computed.update(apply_node.outputs)
        # Later nodes overwrite earlier entries, so after the loop each
        # variable maps to the last node that consumes it.
        for in_var in apply_node.inputs:
            last_user[in_var] = apply_node
    return computed, last_user
class PerformLinker(LocalLinker):
    """
    Basic L{Linker} subclass that calls the perform method on each L{Op} in
    the L{FunctionGraph} in the order given by L{Linker.schedule}.
    """

    def __init__(self, allow_gc=None, schedule=None):
        # allow_gc defaults to the global config flag; it controls whether
        # intermediate storage is freed as soon as it has no later consumer.
        if allow_gc is None:
            allow_gc = config.allow_gc
        self.fgraph = None
        super().__init__(allow_gc=allow_gc, scheduler=schedule)

    def accept(self, fgraph, no_recycling=None, profile=None):
        """
        Associate this linker with `fgraph`.

        Parameters
        ----------
        fgraph
            A PerformLinker can have accepted one FunctionGraph instance at a time.
        no_recycling
            WRITEME

        Returns
        -------
        object
            self (TODO: WHY? Who calls this function?)
        """
        if no_recycling is None:
            no_recycling = []
        if self.fgraph is not None and self.fgraph is not fgraph:
            # Already tied to a different graph: delegate to a fresh linker
            # instead of mutating this one.
            return type(self)(allow_gc=self.allow_gc).accept(
                fgraph, no_recycling, profile
            )
        # raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
        self.fgraph = fgraph
        self.no_recycling = no_recycling
        return self

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """
        Returns Function to run all nodes, list of input containers, list of outputs

        Parameters
        ----------
        input_storage
            list of storages corresponding to fgraph.inputs
        output_storage
            list of storages corresponding to fgraph.outputs

        Returns
        -------
        object
            Function to run all nodes, list of input containers, list of output
            containers, list of thunks (for all programs), list of nodes
            (for all programs).
        """
        fgraph = self.fgraph
        order = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, order, input_storage, output_storage, storage_map
        )

        # compute_map flags which variables already hold a value; graph
        # inputs (owner is None) start out "computed".
        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]

        thunks = []
        for node in order:
            # Make sure we don't use the C version of the code, but rather only
            # the python version ("py" impl argument below).
            # Note : ops that implement their own make_thunk don't usually
            # have this attribute defined !!
            thunks += [
                node.op.make_thunk(node, storage_map, compute_map, no_recycling, "py")
            ]
            thunks[-1].inputs = [storage_map[v] for v in node.inputs]
            thunks[-1].outputs = [storage_map[v] for v in node.outputs]

        computed, last_user = gc_helper(order)

        if self.allow_gc:
            post_thunk_old_storage = []
        else:
            post_thunk_old_storage = None

        for node in order:
            if self.allow_gc:
                # For each thunk, collect the storage cells that may be freed
                # right after it runs: values computed by this program that
                # are not graph outputs and have no later consumer.
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )

        if no_recycling is True:
            # True seems like some special code for *everything*?? -JB
            # FunctionMaker always passes a list I think -JB
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        # The function that actually runs your program is one of the f's in streamline.
        f = streamline(
            fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling
        )

        f.allow_gc = (
            self.allow_gc
        )  # HACK: this is a way of passing an arg to Function.__call__
        add_clear_storage(f, computed, storage_map)
        f.storage_map = storage_map

        return (
            f,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, readonly=True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            order,
        )
class WrapLinker(Linker):
    """
    This class makes it easier to run several L{LocalLinker}s in parallel, and
    offers some control over how each thunk is run.

    A wrapper function must be provided, and it can be used to execute the
    thunks, inspect the nodes, print stuff out, etc.

    The constructor initializes a WrapLinker.

    Parameters
    ----------
    linkers : list of L{LocalLinker} subclasses, whose make_all() method returns
        thunks in the same order.
        For each node in the graph, each linker will provide a
        thunk. This class makes it possible to iterate over each linker's
        program in parallel.
    wrapper : lambda (fgraph, i, i_node, i_thunk1, i_thunk2, ...) : None
        Does some user-defined action for the i'th element of the program.
        i_thunk<n> is the thunk returned by the n'th linker. (If you want
        to run the program, make sure to call the necessary thunks in this
        function.)

    Notes
    -----
    The outputs of the first linker will be returned.

    This linker ensures that each linker has its own storage for inputs and
    outputs and intermediate variables. There is no interference between
    linkers.
    """

    def __init__(self, linkers, wrapper):
        self.fgraph = None
        self.linkers = linkers
        self.wrapper = wrapper

    def __copy__(self):
        """
        Shallow copy of a WrapLinker.

        Returns
        -------
        object
            A copy of self, where each of the linkers in self.linkers
            have been shallow-copied.

        It is useful because in FunctionMaker, copy.copy is called on the
        Mode's linker, so that it is not modified inplace when linker.accept()
        is called. In this case, we want the wrapped linkers to be copied too.
        """
        other = self.__class__(
            linkers=[copy(x) for x in self.linkers], wrapper=self.wrapper
        )
        return other

    def clone(self, allow_gc=None):
        # Unlike __copy__, clone() rebuilds each wrapped linker through its
        # own clone(), forwarding the allow_gc setting.
        return self.__class__(
            linkers=[x.clone(allow_gc=allow_gc) for x in self.linkers],
            wrapper=self.wrapper,
        )

    def accept(self, fgraph, no_recycling=None, profile=None):
        """
        Parameters
        ----------
        fgraph : gof.FunctionGraph
            The fgraph which we will link.
        no_recycling : a list of Variables that belong to fgraph.
            If a Variable is in no_recycling, L{WrapLinker} will clear
            the output storage associated to it (for each linker in linkers)
            during the computation to avoid reusing it.
        """
        if no_recycling is None:
            no_recycling = []
        if self.fgraph is not None and self.fgraph is not fgraph:
            # Already bound to another graph: hand off to a fresh WrapLinker.
            return type(self)(self.linkers, self.wrapper).accept(fgraph, no_recycling)

        self.fgraph = fgraph
        self.no_recycling = no_recycling
        self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers]
        return self

    def pre(self, f, inputs, order, thunk_groups):
        # Hook invoked once per call of the compiled function, before any
        # thunk runs (see f() in make_thunk); subclasses may override.
        pass

    def make_thunk(self, **kwargs):
        no_recycling = self.no_recycling

        # Build each wrapped linker's program. Only the first linker receives
        # the caller's input_storage; the others allocate their own storage so
        # the programs do not interfere.
        make_all = [self.linkers[0].make_all(**kwargs)]
        kwargs.pop("input_storage", None)
        make_all += [x.make_all(**kwargs) for x in self.linkers[1:]]

        fns, input_lists, output_lists, thunk_lists, order_lists = zip(*make_all)

        # All linkers must schedule the nodes identically for the per-node
        # thunk groups below to make sense.
        order_list0 = order_lists[0]
        for order_list in order_lists[1:]:
            if not order_list0 == order_list:
                raise Exception(
                    "All linkers to WrapLinker should execute operations in the same order."
                )

        inputs0 = input_lists[0]
        outputs0 = output_lists[0]

        # thunk_groups[i] holds every linker's thunk for node i.
        thunk_groups = list(zip(*thunk_lists))
        order = [x[0] for x in zip(*order_lists)]

        # Collect the output storage cells that must be cleared before each
        # run because their variables appear in no_recycling.
        to_reset = []
        for thunks, node in zip(thunk_groups, order):
            for j, output in enumerate(node.outputs):
                if output in no_recycling:
                    for thunk in thunks:
                        to_reset.append(thunk.outputs[j])

        wrapper = self.wrapper
        pre = self.pre

        def f():
            # Propagate the first linker's input values to every other
            # linker's storage so all programs start from identical inputs.
            for inputs in input_lists[1:]:
                for input1, input2 in zip(inputs0, inputs):
                    input2.storage[0] = copy(input1.storage[0])
            for x in to_reset:
                x[0] = None
            pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
            for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
                try:
                    wrapper(self.fgraph, i, node, *thunks)
                except Exception:
                    raise_with_op(self.fgraph, node, *thunks)

        f.thunk_groups = thunk_groups
        return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers):
    """
    Variant on WrapLinker that runs a series of wrapper functions instead of
    just one.
    """

    def wrapper(*args):
        # Invoke every registered wrapper in order, with the same arguments.
        for wrap_fn in wrappers:
            wrap_fn(*args)

    return WrapLinker(linkers, wrapper)
import io
import sys
import traceback
from copy import copy
from io import StringIO
from sys import getsizeof
from warnings import warn
import warnings
from operator import itemgetter
import numpy as np
import theano
from theano.gof import graph, utils
from theano.link.basic import Container, Linker, LocalLinker
from theano import config
from theano.gof.fg import FunctionGraph
from .utils import undef
__excepthook = sys.excepthook
def log_thunk_trace(value, f=sys.stderr):
def __log_thunk_trace(value, handler):
"""
Log Theano's diagnostic stack trace for an exception
raised by raise_with_op.
"""
# in future, consider accepting `write` as arg rather than file
# to support writing to a logger
def write(msg):
print(f"log_thunk_trace: {msg.strip()}", file=f)
print(f"log_thunk_trace: {msg.strip()}", file=handler)
if hasattr(value, "__thunk_trace__"):
trace2 = value.__thunk_trace__
......@@ -51,40 +42,41 @@ def log_thunk_trace(value, f=sys.stderr):
)
def thunk_hook(type, value, trace):
"""
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
def set_excepthook(handler: io.TextIOWrapper):
def thunk_hook(type, value, trace):
"""
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
The normal excepthook is then called.
The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes
-----
This hook replaced in testing, so it does not run.
"""
log_thunk_trace(value)
__excepthook(type, value, trace)
Notes
-----
This hook replaced in testing, so it does not run.
"""
__log_thunk_trace(value, handler=handler)
sys.__excepthook__(type, value, trace)
sys.excepthook = thunk_hook
sys.excepthook = thunk_hook
# TODO: Make this work with linker defined schedule
def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
def raise_with_op(
fgraph: FunctionGraph, node, thunk=None, exc_info=None, storage_map=None
):
"""
Re-raise an exception while annotating the exception object with
debug info.
......@@ -117,7 +109,10 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
The exception is not annotated if it is of type `KeyboardInterrupt`.
TODO: Make this work with linker defined schedule
"""
verbosity = config.exception_verbosity
if exc_info is None:
exc_info = sys.exc_info()
exc_type, exc_value, exc_trace = exc_info
......@@ -168,7 +163,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
+ f"\nInputs strides: {strides}"
+ f"\nInputs values: {scalar_values}"
)
if theano.config.exception_verbosity == "high":
if verbosity == "high":
detailed_err_msg += "\nInputs type_num: %s" % str(
[getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs]
)
......@@ -188,7 +183,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg += "\nBacktrace when the node is created(use Theano flag traceback__limit=N to make it longer):\n"
# Print separate message for each element in the list of batcktraces
sio = StringIO()
sio = io.StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
detailed_err_msg += str(sio.getvalue())
......@@ -201,15 +196,17 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
" Theano optimizations can be disabled with 'optimizer=None'."
)
if theano.config.exception_verbosity == "high":
if verbosity == "high":
import theano.printing
f = StringIO()
f = io.StringIO()
theano.printing.debugprint(node, file=f, stop_on_name=True, print_type=True)
detailed_err_msg += "\nDebugprint of the apply node: \n"
detailed_err_msg += f.getvalue()
# Prints output_map
if theano.config.exception_verbosity == "high" and storage_map is not None:
if verbosity == "high" and storage_map is not None:
detailed_err_msg += "\nStorage map footprint:\n"
shared_input_list = [
item
......@@ -283,7 +280,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
if k.type.may_share_memory(data, input_data):
total_size -= sz
else:
bytes = getsizeof(storage_map[k][0])
bytes = sys.getsizeof(storage_map[k][0])
storage_map_item.append(bytes)
storage_map_item.append(-1)
......@@ -297,8 +294,6 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
storage_map_item.append(None)
storage_map_list.append(storage_map_item)
from operator import itemgetter
storage_map_list.sort(key=itemgetter(3), reverse=True)
for item in storage_map_list:
if item[3] == -1:
......@@ -317,11 +312,11 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: {} Byte(s) {:.3f} GB\n".format(
total_size,
total_size / 1024.0 / 1024 / 1024,
total_size / 1024 / 1024 / 1024,
)
detailed_err_msg += " TotalSize inputs: {} Byte(s) {:.3f} GB\n".format(
total_size_inputs,
total_size_inputs / 1024.0 / 1024 / 1024,
total_size_inputs / 1024 / 1024 / 1024,
)
else:
......@@ -335,533 +330,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
str(exc_value) + detailed_err_msg + "\n" + "\n".join(hints)
)
except TypeError:
warn(f"{exc_type} error does not allow us to add extra error message")
warnings.warn(f"{exc_type} error does not allow us to add extra error message")
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
raise exc_value.with_traceback(exc_trace)
def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
    """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.

    :param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
    :param order: an iterable over Apply instances (in program running order)
    :param input_storage: None or existing input storage (see below)
    :param output_storage: None or existing output storage (see below)

    :rtype: 3-tuple
    :returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)

    Parameters
    ----------
    fgraph
        The current fgraph. This function uses the inputs and outputs
        attributes.
    order
        An iterable over Apply instances (in program running order).
    input_storage
        None or existing input storage (see below).
    output_storage
        None or existing output storage (see below).
    storage_map
        None, or an existing Variable -> storage mapping to extend in place.

    Returns
    -------
    3-tuple
        List of storage for inputs, list of storage for outputs, and
        the `storage_map`.

    Extended summary
    ----------------
    This function iterates over the nodes in `order` and ensures that for every
    input and output `Variable`, there is a unique storage container. This is
    returned as a dictionary Variable -> storage called the `storage_map`.

    This function also returns `input_storage`, which is a list of storages
    corresponding to fgraph.inputs.
    This function also returns `output_storage`, which is a list of storages
    corresponding to fgraph.outputs.
    """
    # each Apply argument's data is stored in a list of length 1 (these lists act like pointers)

    if storage_map is None:
        storage_map = {}

    # input_storage is a list of data-containers for the inputs.
    if input_storage is None:
        input_storage = [[None] for input in fgraph.inputs]
    else:
        assert len(fgraph.inputs) == len(input_storage)

    # add input storage into storage_map
    for r, storage in zip(fgraph.inputs, input_storage):
        if r in storage_map:
            # A pre-existing entry must be the exact same cell object,
            # otherwise the two "pointers" would silently diverge.
            assert storage_map[r] is storage, (
                "Given input_storage conflicts "
                "with storage in given storage_"
                "map. Given input_storage: ",
                storage,
                "Storage in storage_ma" "p: ",
                storage_map[r],
            )
        else:
            storage_map[r] = storage

    # for orphan in fgraph.orphans:
    #    if not isinstance(orphan, Constant):
    #        raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
    #    storage_map[orphan] = [orphan.data]

    # allocate output storage
    if output_storage is not None:
        assert len(fgraph.outputs) == len(output_storage)
        for r, storage in zip(fgraph.outputs, output_storage):
            if r in storage_map:
                assert storage_map[r] is storage, (
                    "Given output_storage confl"
                    "icts with storage in given"
                    " storage_map. Given output"
                    "_storage: ",
                    storage,
                    "Sto" "rage in storage_map: ",
                    storage_map[r],
                )
            else:
                storage_map[r] = storage

    # allocate storage for intermediate computation
    for node in order:
        for r in node.inputs:
            if r not in storage_map:
                # An input with no storage yet must be a constant; seed its
                # cell with the constant's value.
                assert isinstance(r, graph.Constant)
                storage_map[r] = [r.data]
        for r in node.outputs:
            storage_map.setdefault(r, [None])
    for r in fgraph.outputs:
        if isinstance(r, graph.Constant):
            storage_map.setdefault(r, [r.data])

    # extract output storage
    if output_storage is None:
        output_storage = [storage_map[r] for r in fgraph.outputs]

    return input_storage, output_storage, storage_map
def add_clear_storage(f, computed, storage_map):
    # Attach a `clear_storage` callable to `f`, letting callers drop the
    # references held by every computed variable's storage cell.
    def clear_storage():
        for c in computed:
            storage_map[c][0] = None

    f.clear_storage = clear_storage
def streamline(
    fgraph,
    thunks,
    order,
    post_thunk_old_storage=None,
    no_recycling=None,
    nice_errors=True,
):
    """
    Build a single zero-argument callable that runs `thunks` in sequence.

    Parameters
    ----------
    fgraph
        The FunctionGraph the thunks belong to (used for error annotation).
    thunks
        The list of program instructions.
    order
        The list of apply instances that gave rise to the thunks
        (same order as thunks).
    post_thunk_old_storage
        A list (corresponding to thunks, order) whose elements are lists of
        storage cells, that should be cleared after running the corresponding
        thunk. A value of None disables this functionality.
    no_recycling
        Storage elements that cannot be 'recycled' by repeatedly executing the
        program. These storage elements are cleared before re-running.
    nice_errors
        Run in such a way that the double-traceback is printed. This costs a
        bit of performance in the inner python loop.
    """
    if no_recycling is None:
        no_recycling = []

    if len(thunks) != len(order):
        raise ValueError(
            "Length of thunks and order must match", (len(thunks), len(order))
        )

    if post_thunk_old_storage:
        if len(thunks) != len(post_thunk_old_storage):
            raise ValueError(
                "Length of thunks and post_thunk_old_storage must match",
                (len(thunks), len(post_thunk_old_storage)),
            )

        def streamline_default_f():
            # Clear storage that must not survive between calls.
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node, old_storage in zip(
                    thunks, order, post_thunk_old_storage
                ):
                    thunk()
                    # Free inputs that will never be read again (gc).
                    for old_s in old_storage:
                        old_s[0] = None
            except Exception:
                raise_with_op(fgraph, node, thunk)

        f = streamline_default_f
    elif nice_errors:

        def streamline_nice_errors_f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node in zip(thunks, order):
                    thunk()
            except Exception:
                raise_with_op(fgraph, node, thunk)

        f = streamline_nice_errors_f
    else:
        # don't worry about raise_with_op, just go a little faster.
        # there is a mix of python and c thunks
        def streamline_fast_f():
            for x in no_recycling:
                x[0] = None
            for thunk in thunks:
                thunk()

        f = streamline_fast_f
    return f
def gc_helper(node_list):
    """
    Return the set of Variable instances which are computed by node_list.

    Parameters
    ----------
    node_list
        List of Apply instances in program execution order.

    Returns
    -------
    2-tuple
        FIRST, the set of Variable instances which are computed by node_list,
        and SECOND a dictionary that maps each Variable instance to the last
        node to use Variable as an input.

    Extended Summary
    ----------------
    This is used to allow garbage collection within graphs.

    It ignores view_map and destroy_map. This isn't needed as python
    have reference count. In Theano gc, we should not take into
    account view_map and destroy_map as if the thunk decided to create
    a new output, we would delay uselessly its gc by Python.
    """
    # for freeing memory
    last_user = {}
    computed = set()
    for node in node_list:
        for input in node.inputs:
            # Later nodes overwrite earlier entries, so after the loop each
            # variable maps to the *last* node that consumes it.
            last_user[input] = node
        for output in node.outputs:
            computed.add(output)
    return computed, last_user
class PerformLinker(LocalLinker):
    """
    Basic L{Linker} subclass that calls the perform method on each L{Op} in
    the L{FunctionGraph} in the order given by L{Linker.schedule}.

    This is the pure-Python reference linker: thunks are requested with
    impl="py" in ``make_all``, so every node runs through its Op's
    ``perform`` method and no C implementation is used.
    """
    def __init__(self, allow_gc=None, schedule=None):
        # allow_gc: free intermediate storage as soon as its last consumer
        # has run; defaults to the global theano.config.allow_gc flag.
        if allow_gc is None:
            allow_gc = theano.config.allow_gc
        self.fgraph = None
        super().__init__(allow_gc=allow_gc, scheduler=schedule)
    def accept(self, fgraph, no_recycling=None, profile=None):
        """
        Associate this linker with `fgraph`.

        Parameters
        ----------
        fgraph
            A PerformLinker can have accepted one FunctionGraph instance at a time.
            If this linker is already bound to a different graph, a fresh
            linker is created and accepts `fgraph` instead.
        no_recycling
            List of variables whose storage must be cleared before each call
            (see `make_all`).

        Returns
        -------
        object
            Either ``self`` or a new PerformLinker bound to `fgraph`; callers
            must use the returned object, not assume it is ``self``.
        """
        if no_recycling is None:
            no_recycling = []
        if self.fgraph is not None and self.fgraph is not fgraph:
            # Already bound to another graph: delegate to a fresh linker
            # instead of rebinding self.
            return type(self)(allow_gc=self.allow_gc).accept(
                fgraph, no_recycling, profile
            )
        # raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
        self.fgraph = fgraph
        self.no_recycling = no_recycling
        return self
    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """
        Returns Function to run all nodes, list of input containers, list of outputs

        Parameters
        ----------
        input_storage
            list of storages corresponding to fgraph.inputs
        output_storage
            list of storages corresponding to fgraph.outputs
        storage_map
            optional pre-existing mapping from Variable to its storage cell

        Returns
        -------
        object
            Function to run all nodes, list of input containers, list of output
            containers, list of thunks (for all programs), list of nodes
            (for all programs).
        """
        fgraph = self.fgraph
        order = self.schedule(fgraph)
        no_recycling = self.no_recycling
        # Build (or complete) the Variable -> single-element-list storage
        # cells shared by all thunks of this program.
        input_storage, output_storage, storage_map = map_storage(
            fgraph, order, input_storage, output_storage, storage_map
        )
        # compute_map[v][0] is True once v's value is available; graph inputs
        # (owner is None) are considered available from the start.
        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]
        thunks = []
        for node in order:
            # Make sure we don't use C version of the code, but rather only
            # the python version
            # Note : ops that implement their own make_thunk don't usually
            # have this attribute defined !!
            thunks += [
                node.op.make_thunk(node, storage_map, compute_map, no_recycling, "py")
            ]
            thunks[-1].inputs = [storage_map[v] for v in node.inputs]
            thunks[-1].outputs = [storage_map[v] for v in node.outputs]
        computed, last_user = gc_helper(order)
        if self.allow_gc:
            post_thunk_old_storage = []
        else:
            post_thunk_old_storage = None
        for node in order:
            if self.allow_gc:
                # For each node, list the storage cells that can be freed
                # right after it runs: intermediates whose last consumer is
                # this node and that are not graph outputs.
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )
        if no_recycling is True:
            # True seems like some special code for *everything*?? -JB
            # FunctionMaker always passes a list I think -JB
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            # Translate Variables to their storage cells; input storage is
            # owned by the caller and never cleared.
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]
        # The function that actually runs your program is one of the f's in streamline.
        f = streamline(
            fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling
        )
        f.allow_gc = (
            self.allow_gc
        )  # HACK: this is a way of passing an arg to Function.__call__
        add_clear_storage(f, computed, storage_map)
        f.storage_map = storage_map
        return (
            f,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, readonly=True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            order,
        )
class WrapLinker(Linker):
    """
    This class makes it easier to run several L{LocalLinker}s in parallel, and
    offers some control over how each thunk is run.

    A wrapper function must be provided, and it can be used to execute the
    thunks, inspect the nodes, print stuff out, etc.

    The constructor initializes a WrapLinker.

    Parameters
    ----------
    linkers : list of L{LocalLinker} subclasses, whose make_all() method returns
        thunks in the same order.
        For each node in the graph, each linker will provide a
        thunk. This class makes it possible to iterate over each linker's
        program in parallel.
    wrapper : lambda (fgraph, i, i_node, i_thunk1, i_thunk2, ...) : None
        Does some user-defined action for the i'th element of the program.
        i_thunk<n> is the thunk returned by the n'th linker. (If you want
        to run the program, make sure to call the necessary thunks in this
        function.)

    Notes
    -----
    The outputs of the first linker will be returned.

    This linker ensures that each linker has its own storage for inputs and
    outputs and intermediate variables. There is no interference between
    linkers.
    """
    def __init__(self, linkers, wrapper):
        self.fgraph = None
        self.linkers = linkers
        self.wrapper = wrapper
    def __copy__(self):
        """
        Shallow copy of a WrapLinker.

        Returns
        -------
        object
            A copy of self, where each of the linkers in self.linkers
            have been shallow-copied.

        It is useful because in FunctionMaker, copy.copy is called on the
        Mode's linker, so that it is not modified inplace when linker.accept()
        is called. In this case, we want the wrapped linkers to be copied too.
        """
        other = self.__class__(
            linkers=[copy(x) for x in self.linkers], wrapper=self.wrapper
        )
        return other
    def clone(self, allow_gc=undef):
        # `undef` is a sentinel defined elsewhere in this module; it is
        # forwarded unchanged to each wrapped linker's clone().
        return self.__class__(
            linkers=[x.clone(allow_gc=allow_gc) for x in self.linkers],
            wrapper=self.wrapper,
        )
    def accept(self, fgraph, no_recycling=None, profile=None):
        """
        Bind this WrapLinker (and all wrapped linkers) to `fgraph`.

        Parameters
        ----------
        fgraph : gof.FunctionGraph
            The fgraph which we will link.
        no_recycling : a list of Variables that belong to fgraph.
            If a Variable is in no_recycling, L{WrapLinker} will clear
            the output storage associated to it (for each linker in linkers)
            during the computation to avoid reusing it.
        """
        if no_recycling is None:
            no_recycling = []
        if self.fgraph is not None and self.fgraph is not fgraph:
            # Already bound to another graph: hand off to a fresh WrapLinker.
            return type(self)(self.linkers, self.wrapper).accept(fgraph, no_recycling)
        self.fgraph = fgraph
        self.no_recycling = no_recycling
        self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers]
        return self
    def pre(self, f, inputs, order, thunk_groups):
        # Hook called once per run, before any thunk executes; subclasses
        # may override it (default is a no-op).
        pass
    def make_thunk(self, **kwargs):
        no_recycling = self.no_recycling
        # Only the first linker receives the caller-provided input_storage;
        # the others get fresh storage, and f() below copies the inputs over
        # so every linker sees the same input values without sharing cells.
        make_all = [self.linkers[0].make_all(**kwargs)]
        kwargs.pop("input_storage", None)
        make_all += [x.make_all(**kwargs) for x in self.linkers[1:]]
        fns, input_lists, output_lists, thunk_lists, order_lists = zip(*make_all)
        order_list0 = order_lists[0]
        for order_list in order_lists[1:]:
            if not order_list0 == order_list:
                raise Exception(
                    "All linkers to WrapLinker should execute operations in the same order."
                )
        inputs0 = input_lists[0]
        outputs0 = output_lists[0]
        # thunk_groups[i] holds the i'th thunk from every linker;
        # order[i] is the corresponding Apply node.
        thunk_groups = list(zip(*thunk_lists))
        order = [x[0] for x in zip(*order_lists)]
        # Storage cells (one per linker) that must be cleared before each run
        # because their variables are in no_recycling.
        to_reset = []
        for thunks, node in zip(thunk_groups, order):
            for j, output in enumerate(node.outputs):
                if output in no_recycling:
                    for thunk in thunks:
                        to_reset.append(thunk.outputs[j])
        wrapper = self.wrapper
        pre = self.pre
        def f():
            # Propagate the first linker's input values to all other linkers.
            for inputs in input_lists[1:]:
                for input1, input2 in zip(inputs0, inputs):
                    input2.storage[0] = copy(input1.storage[0])
            for x in to_reset:
                x[0] = None
            pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
            for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
                try:
                    wrapper(self.fgraph, i, node, *thunks)
                except Exception:
                    raise_with_op(self.fgraph, node, *thunks)
        f.thunk_groups = thunk_groups
        return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers):
    """
    Variant on WrapLinker that runs a series of wrapper functions instead of
    just one.

    Each function in `wrappers` is called in sequence, with the same
    arguments, for every element of the program.
    """
    def _chained_wrapper(*args):
        for wrap_fn in wrappers:
            wrap_fn(*args)
    return WrapLinker(linkers, _chained_wrapper)
......@@ -3,14 +3,14 @@ from warnings import warn
from theano.gof import utils
from theano.gof.graph import Constant
from theano.gof.link import (
from theano.link.basic import (
Container,
PerformLinker,
add_clear_storage,
gc_helper,
map_storage,
streamline,
)
from theano.link import Container
class JAXLinker(PerformLinker):
......
......@@ -53,7 +53,7 @@ from collections import OrderedDict
import numpy as np
import theano
from theano import compile, config, gof, gradient, tensor
from theano import compile, config, gof, gradient, link, tensor
from theano.compile.builders import infer_shape
from theano.compile.function import function
from theano.compile.io import In, Out
......@@ -1465,7 +1465,7 @@ class Scan(PureOp):
# done by raise_with_op is not implemented in C.
if hasattr(fn, "thunks"):
# For the CVM
gof.link.raise_with_op(
link.raise_with_op(
self.fn.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error],
......@@ -1475,7 +1475,7 @@ class Scan(PureOp):
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
gof.link.raise_with_op(
link.raise_with_op(
self.fn.maker.fgraph, fn.nodes[fn.position_of_error]
)
else:
......
......@@ -61,7 +61,7 @@ cimport numpy
import copy
import time
from theano import gof
from theano import gof, link
def get_version():
......@@ -405,7 +405,7 @@ def perform(
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
gof.link.raise_with_op(fn.maker.fgraph,
link.raise_with_op(fn.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
......@@ -413,7 +413,7 @@ def perform(
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
gof.link.raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
link.raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论