提交 5cc5290d authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Merge linker utility functions, remove add_clear_storage

上级 71f727da
...@@ -29,7 +29,7 @@ from theano.compile.mode import Mode, register_mode ...@@ -29,7 +29,7 @@ from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard from theano.compile.ops import OutputGuard, _output_guard
from theano.gof import graph, ops_with_inner_function, utils from theano.gof import graph, ops_with_inner_function, utils
from theano.link.basic import LocalLinker from theano.link.basic import LocalLinker
from theano.link.debugging import raise_with_op from theano.link.utils import raise_with_op
from theano.utils import get_unbound_function from theano.utils import get_unbound_function
......
...@@ -5,11 +5,5 @@ ...@@ -5,11 +5,5 @@
PerformLinker, PerformLinker,
WrapLinker, WrapLinker,
WrapLinkerMany, WrapLinkerMany,
gc_helper,
map_storage,
streamline,
) )
from theano.link.debugging import raise_with_op, register_thunk_trace_excepthook from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
register_thunk_trace_excepthook()
...@@ -3,10 +3,10 @@ from copy import copy, deepcopy ...@@ -3,10 +3,10 @@ from copy import copy, deepcopy
from theano import config, utils from theano import config, utils
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant from theano.gof.graph import Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.utils import to_return_values from theano.gof.utils import to_return_values
from theano.link.debugging import raise_with_op from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
class Container: class Container:
...@@ -281,247 +281,6 @@ class LocalLinker(Linker): ...@@ -281,247 +281,6 @@ class LocalLinker(Linker):
) )
def map_storage(
    fgraph: "FunctionGraph",
    order: typing.Iterable["Apply"],
    input_storage: typing.Optional[typing.List],
    output_storage: typing.Optional[typing.List],
    storage_map: typing.Optional[typing.Dict] = None,
) -> typing.Tuple[typing.List, typing.List, typing.Dict]:
    """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.

    Parameters
    ----------
    fgraph
        The current fgraph. This function uses the inputs and outputs
        attributes.
    order
        An iterable over Apply instances (in program running order).
    input_storage
        None or existing input storage, one cell per entry of
        ``fgraph.inputs``. When None, a fresh ``[None]`` cell is created for
        every input.
    output_storage
        None or existing output storage, one cell per entry of
        ``fgraph.outputs``. When None, the output cells are taken from the
        `storage_map` once it has been populated.
    storage_map
        None or an existing dictionary Variable -> storage. When given, it is
        updated in place and returned; when None, a new dictionary is created.

    Returns
    -------
    3-tuple
        List of storage for inputs, list of storage for outputs, and
        the `storage_map`.

    Raises
    ------
    AssertionError
        If a given storage list does not have the same length as the
        corresponding fgraph attribute, if a given cell conflicts with the
        cell already registered in `storage_map` for the same variable, or if
        a node input without storage is not a `Constant`.

    Extended summary
    ----------------
    This function iterates over the nodes in `order` and ensures that for every
    input and output `Variable`, there is a unique storage container. This is
    returned as a dictionary Variable -> storage called the `storage_map`.

    This function also returns `input_storage`, which is a list of storages
    corresponding to fgraph.inputs.

    This function also returns `output_storage`, which is a list of storages
    corresponding to fgraph.outputs.

    """
    # Each Apply argument's data is stored in a list of length 1 (these lists
    # act like pointers).
    if storage_map is None:
        storage_map = {}

    # input_storage is a list of data-containers for the inputs.
    if input_storage is None:
        input_storage = [[None] for _ in fgraph.inputs]
    else:
        assert len(fgraph.inputs) == len(input_storage)

    # Add input storage into storage_map; a pre-registered cell must be the
    # very same list object (identity, not equality).
    for r, storage in zip(fgraph.inputs, input_storage):
        if r in storage_map:
            assert storage_map[r] is storage, (
                "Given input_storage conflicts with storage in given "
                "storage_map. Given input_storage: ",
                storage,
                "Storage in storage_map: ",
                storage_map[r],
            )
        else:
            storage_map[r] = storage

    # Register caller-provided output storage.
    if output_storage is not None:
        assert len(fgraph.outputs) == len(output_storage)
        for r, storage in zip(fgraph.outputs, output_storage):
            if r in storage_map:
                assert storage_map[r] is storage, (
                    "Given output_storage conflicts with storage in given "
                    "storage_map. Given output_storage: ",
                    storage,
                    "Storage in storage_map: ",
                    storage_map[r],
                )
            else:
                storage_map[r] = storage

    # Allocate storage for intermediate computation.
    for node in order:
        for r in node.inputs:
            if r not in storage_map:
                # Anything that is neither a graph input nor produced by an
                # earlier node must carry its own data.
                assert isinstance(r, Constant)
                storage_map[r] = [r.data]
        for r in node.outputs:
            storage_map.setdefault(r, [None])

    # Constant outputs may not be touched by any node in `order`.
    for r in fgraph.outputs:
        if isinstance(r, Constant):
            storage_map.setdefault(r, [r.data])

    # Extract output storage.
    if output_storage is None:
        output_storage = [storage_map[r] for r in fgraph.outputs]

    return input_storage, output_storage, storage_map
def add_clear_storage(f, computed, storage_map):
    """Attach a ``clear_storage`` callable to `f`.

    Calling ``f.clear_storage()`` sets the first element of
    ``storage_map[c]`` to None for every ``c`` in `computed`.
    """

    def clear_storage():
        for variable in computed:
            cell = storage_map[variable]
            cell[0] = None

    f.clear_storage = clear_storage
def streamline(
    fgraph: "FunctionGraph",
    thunks,
    order,
    post_thunk_old_storage=None,
    no_recycling=None,
    nice_errors=True,
) -> typing.Callable[[], None]:
    """Combine `thunks` into one zero-argument callable that runs them in order.

    Parameters
    ----------
    fgraph
        The graph the thunks belong to; it is only forwarded to
        `raise_with_op` when a thunk raises.
    thunks
        The list of program instructions.
    order
        The list of apply instances that gave rise to the thunks
        (same order as thunks).
    post_thunk_old_storage
        A list (corresponding to thunks, order) whose elements are lists of
        storage cells, that should be cleared after running the corresponding
        thunk. A falsy value (None or an empty list) disables this
        functionality.
    no_recycling
        Storage elements that cannot be 'recycled' by repeatedly executing the
        program. These storage elements are cleared before re-running.
    nice_errors
        Only consulted when `post_thunk_old_storage` is disabled. When true,
        an exception raised by a thunk is passed to `raise_with_op` together
        with the failing node and thunk; this costs a bit of performance in
        the inner python loop. When false, thunks run without that handler.

    Returns
    -------
    callable
        A function taking no arguments that clears the `no_recycling` cells
        and then calls every thunk in sequence.

    Raises
    ------
    ValueError
        If `thunks` and `order` differ in length, or if
        `post_thunk_old_storage` is used and its length differs from
        `thunks`.

    """
    if no_recycling is None:
        no_recycling = []

    if len(thunks) != len(order):
        raise ValueError(
            "Length of thunks and order must match", (len(thunks), len(order))
        )

    if post_thunk_old_storage:
        if len(thunks) != len(post_thunk_old_storage):
            raise ValueError(
                "Length of thunks and post_thunk_old_storage must match",
                (len(thunks), len(post_thunk_old_storage)),
            )

        def streamline_default_f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node, old_storage in zip(
                    thunks, order, post_thunk_old_storage
                ):
                    thunk()
                    # Free the cells that are no longer needed once this
                    # thunk has run.
                    for old_s in old_storage:
                        old_s[0] = None
            except Exception:
                # `node` and `thunk` still refer to the failing step here.
                raise_with_op(fgraph, node, thunk)

        f = streamline_default_f
    elif nice_errors:

        def streamline_nice_errors_f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node in zip(thunks, order):
                    thunk()
            except Exception:
                # `node` and `thunk` still refer to the failing step here.
                raise_with_op(fgraph, node, thunk)

        f = streamline_nice_errors_f
    else:
        # don't worry about raise_with_op, just go a little faster.
        # there is a mix of python and c thunks
        def streamline_fast_f():
            for x in no_recycling:
                x[0] = None
            for thunk in thunks:
                thunk()

        f = streamline_fast_f
    return f
def gc_helper(
    node_list: typing.List["Apply"],
) -> typing.Tuple[typing.Set, typing.Dict]:
    """
    Return the set of Variable instances which are computed by node_list.

    Parameters
    ----------
    node_list
        List of Apply instances in program execution order.

    Returns
    -------
    2-tuple
        FIRST, the set of Variable instances which are computed by node_list,
        and SECOND a dictionary that maps each Variable instance to the last
        node to use that Variable as an input.

    Extended Summary
    ----------------
    This is used to allow garbage collection within graphs.

    It ignores view_map and destroy_map. This isn't needed as Python
    has reference counting. In Theano gc, we should not take into
    account view_map and destroy_map as if the thunk decided to create
    a new output, we would uselessly delay its gc by Python.

    """
    # for freeing memory
    last_user = {}
    computed = set()
    for node in node_list:
        # Later nodes overwrite earlier ones, so each variable ends up mapped
        # to its last consumer in execution order.
        for var in node.inputs:
            last_user[var] = node
        computed.update(node.outputs)
    return computed, last_user
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
""" """
Basic L{Linker} subclass that calls the perform method on each L{Op} in Basic L{Linker} subclass that calls the perform method on each L{Op} in
...@@ -641,7 +400,6 @@ class PerformLinker(LocalLinker): ...@@ -641,7 +400,6 @@ class PerformLinker(LocalLinker):
f.allow_gc = ( f.allow_gc = (
self.allow_gc self.allow_gc
) # HACK: this is a way of passing an arg to Function.__call__ ) # HACK: this is a way of passing an arg to Function.__call__
add_clear_storage(f, computed, storage_map)
f.storage_map = storage_map f.storage_map = storage_map
return ( return (
......
...@@ -3,14 +3,7 @@ from warnings import warn ...@@ -3,14 +3,7 @@ from warnings import warn
from theano.gof import utils from theano.gof import utils
from theano.gof.graph import Constant from theano.gof.graph import Constant
from theano.link.basic import ( from theano.link import Container, PerformLinker, gc_helper, map_storage, streamline
Container,
PerformLinker,
add_clear_storage,
gc_helper,
map_storage,
streamline,
)
class JAXLinker(PerformLinker): class JAXLinker(PerformLinker):
...@@ -184,7 +177,6 @@ class JAXLinker(PerformLinker): ...@@ -184,7 +177,6 @@ class JAXLinker(PerformLinker):
) )
fn.allow_gc = self.allow_gc fn.allow_gc = self.allow_gc
add_clear_storage(fn, computed, storage_map)
fn.storage_map = storage_map fn.storage_map = storage_map
return ( return (
......
import io import io
import sys import sys
import traceback import traceback
import typing
import warnings import warnings
from operator import itemgetter from operator import itemgetter
...@@ -8,6 +9,240 @@ import numpy as np ...@@ -8,6 +9,240 @@ import numpy as np
from theano import config, utils from theano import config, utils
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant
def map_storage(
    fgraph: "FunctionGraph",
    order: typing.Iterable["Apply"],
    input_storage: typing.Optional[typing.List],
    output_storage: typing.Optional[typing.List],
    storage_map: typing.Optional[typing.Dict] = None,
) -> typing.Tuple[typing.List, typing.List, typing.Dict]:
    """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.

    Parameters
    ----------
    fgraph
        The current fgraph. This function uses the inputs and outputs
        attributes.
    order
        An iterable over Apply instances (in program running order).
    input_storage
        None or existing input storage, one cell per entry of
        ``fgraph.inputs``. When None, a fresh ``[None]`` cell is created for
        every input.
    output_storage
        None or existing output storage, one cell per entry of
        ``fgraph.outputs``. When None, the output cells are taken from the
        `storage_map` once it has been populated.
    storage_map
        None or an existing dictionary Variable -> storage. When given, it is
        updated in place and returned; when None, a new dictionary is created.

    Returns
    -------
    3-tuple
        List of storage for inputs, list of storage for outputs, and
        the `storage_map`.

    Raises
    ------
    AssertionError
        If a given storage list does not have the same length as the
        corresponding fgraph attribute, if a given cell conflicts with the
        cell already registered in `storage_map` for the same variable, or if
        a node input without storage is not a `Constant`.

    Extended summary
    ----------------
    This function iterates over the nodes in `order` and ensures that for every
    input and output `Variable`, there is a unique storage container. This is
    returned as a dictionary Variable -> storage called the `storage_map`.

    This function also returns `input_storage`, which is a list of storages
    corresponding to fgraph.inputs.

    This function also returns `output_storage`, which is a list of storages
    corresponding to fgraph.outputs.

    """
    # Each Apply argument's data is stored in a list of length 1 (these lists
    # act like pointers).
    if storage_map is None:
        storage_map = {}

    # input_storage is a list of data-containers for the inputs.
    if input_storage is None:
        input_storage = [[None] for _ in fgraph.inputs]
    else:
        assert len(fgraph.inputs) == len(input_storage)

    # Add input storage into storage_map; a pre-registered cell must be the
    # very same list object (identity, not equality).
    for r, storage in zip(fgraph.inputs, input_storage):
        if r in storage_map:
            assert storage_map[r] is storage, (
                "Given input_storage conflicts with storage in given "
                "storage_map. Given input_storage: ",
                storage,
                "Storage in storage_map: ",
                storage_map[r],
            )
        else:
            storage_map[r] = storage

    # Register caller-provided output storage.
    if output_storage is not None:
        assert len(fgraph.outputs) == len(output_storage)
        for r, storage in zip(fgraph.outputs, output_storage):
            if r in storage_map:
                assert storage_map[r] is storage, (
                    "Given output_storage conflicts with storage in given "
                    "storage_map. Given output_storage: ",
                    storage,
                    "Storage in storage_map: ",
                    storage_map[r],
                )
            else:
                storage_map[r] = storage

    # Allocate storage for intermediate computation.
    for node in order:
        for r in node.inputs:
            if r not in storage_map:
                # Anything that is neither a graph input nor produced by an
                # earlier node must carry its own data.
                assert isinstance(r, Constant)
                storage_map[r] = [r.data]
        for r in node.outputs:
            storage_map.setdefault(r, [None])

    # Constant outputs may not be touched by any node in `order`.
    for r in fgraph.outputs:
        if isinstance(r, Constant):
            storage_map.setdefault(r, [r.data])

    # Extract output storage.
    if output_storage is None:
        output_storage = [storage_map[r] for r in fgraph.outputs]

    return input_storage, output_storage, storage_map
def streamline(
    fgraph: "FunctionGraph",
    thunks,
    order,
    post_thunk_old_storage=None,
    no_recycling=None,
    nice_errors=True,
) -> typing.Callable[[], None]:
    """Combine `thunks` into one zero-argument callable that runs them in order.

    Parameters
    ----------
    fgraph
        The graph the thunks belong to; it is only forwarded to
        `raise_with_op` when a thunk raises.
    thunks
        The list of program instructions.
    order
        The list of apply instances that gave rise to the thunks
        (same order as thunks).
    post_thunk_old_storage
        A list (corresponding to thunks, order) whose elements are lists of
        storage cells, that should be cleared after running the corresponding
        thunk. A falsy value (None or an empty list) disables this
        functionality.
    no_recycling
        Storage elements that cannot be 'recycled' by repeatedly executing the
        program. These storage elements are cleared before re-running.
    nice_errors
        Only consulted when `post_thunk_old_storage` is disabled. When true,
        an exception raised by a thunk is passed to `raise_with_op` together
        with the failing node and thunk; this costs a bit of performance in
        the inner python loop. When false, thunks run without that handler.

    Returns
    -------
    callable
        A function taking no arguments that clears the `no_recycling` cells
        and then calls every thunk in sequence.

    Raises
    ------
    ValueError
        If `thunks` and `order` differ in length, or if
        `post_thunk_old_storage` is used and its length differs from
        `thunks`.

    """
    if no_recycling is None:
        no_recycling = []

    if len(thunks) != len(order):
        raise ValueError(
            "Length of thunks and order must match", (len(thunks), len(order))
        )

    if post_thunk_old_storage:
        if len(thunks) != len(post_thunk_old_storage):
            raise ValueError(
                "Length of thunks and post_thunk_old_storage must match",
                (len(thunks), len(post_thunk_old_storage)),
            )

        def streamline_default_f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node, old_storage in zip(
                    thunks, order, post_thunk_old_storage
                ):
                    thunk()
                    # Free the cells that are no longer needed once this
                    # thunk has run.
                    for old_s in old_storage:
                        old_s[0] = None
            except Exception:
                # `node` and `thunk` still refer to the failing step here.
                raise_with_op(fgraph, node, thunk)

        f = streamline_default_f
    elif nice_errors:

        def streamline_nice_errors_f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node in zip(thunks, order):
                    thunk()
            except Exception:
                # `node` and `thunk` still refer to the failing step here.
                raise_with_op(fgraph, node, thunk)

        f = streamline_nice_errors_f
    else:
        # don't worry about raise_with_op, just go a little faster.
        # there is a mix of python and c thunks
        def streamline_fast_f():
            for x in no_recycling:
                x[0] = None
            for thunk in thunks:
                thunk()

        f = streamline_fast_f
    return f
def gc_helper(
    node_list: typing.List["Apply"],
) -> typing.Tuple[typing.Set, typing.Dict]:
    """
    Return the set of Variable instances which are computed by node_list.

    Parameters
    ----------
    node_list
        List of Apply instances in program execution order.

    Returns
    -------
    2-tuple
        FIRST, the set of Variable instances which are computed by node_list,
        and SECOND a dictionary that maps each Variable instance to the last
        node to use that Variable as an input.

    Extended Summary
    ----------------
    This is used to allow garbage collection within graphs.

    It ignores view_map and destroy_map. This isn't needed as Python
    has reference counting. In Theano gc, we should not take into
    account view_map and destroy_map as if the thunk decided to create
    a new output, we would uselessly delay its gc by Python.

    """
    # for freeing memory
    last_user = {}
    computed = set()
    for node in node_list:
        # Later nodes overwrite earlier ones, so each variable ends up mapped
        # to its last consumer in execution order.
        for var in node.inputs:
            last_user[var] = node
        computed.update(node.outputs)
    return computed, last_user
def raise_with_op( def raise_with_op(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论