提交 5496618e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move memoize, uniq, difference, flatten, *_return_values to theano.utils

上级 64952355
...@@ -28,10 +28,11 @@ from theano.compile.function.types import ( ...@@ -28,10 +28,11 @@ from theano.compile.function.types import (
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard from theano.compile.ops import OutputGuard, _output_guard
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import graph, ops_with_inner_function, utils from theano.gof import graph, ops_with_inner_function
from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, LocalLinker from theano.link.basic import Container, LocalLinker
from theano.link.utils import map_storage, raise_with_op from theano.link.utils import map_storage, raise_with_op
from theano.utils import get_unbound_function from theano.utils import difference, get_unbound_function
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -1827,14 +1828,14 @@ class _Linker(LocalLinker): ...@@ -1827,14 +1828,14 @@ class _Linker(LocalLinker):
or debug or debug
or not isinstance(node.op, gof.op.COp) or not isinstance(node.op, gof.op.COp)
): ):
raise utils.MethodNotDefined() raise MethodNotDefined()
node.op.prepare_node(node, storage_map, compute_map, "c") node.op.prepare_node(node, storage_map, compute_map, "c")
thunk = node.op.make_c_thunk( thunk = node.op.make_c_thunk(
node, storage_map, compute_map, no_recycling node, storage_map, compute_map, no_recycling
) )
thunks_c.append(thunk) thunks_c.append(thunk)
except (NotImplementedError, utils.MethodNotDefined): except (NotImplementedError, MethodNotDefined):
thunks_c.append(None) thunks_c.append(None)
# Pure ops don't really have a perform ( or their perform just # Pure ops don't really have a perform ( or their perform just
...@@ -1885,7 +1886,7 @@ class _Linker(LocalLinker): ...@@ -1885,7 +1886,7 @@ class _Linker(LocalLinker):
# function's outputs. no_recycling_map will be used in f() below. # function's outputs. no_recycling_map will be used in f() below.
if self.no_recycling is True: if self.no_recycling is True:
no_recycling_map = list(storage_map.values()) no_recycling_map = list(storage_map.values())
no_recycling_map = utils.difference(no_recycling_map, input_storage) no_recycling_map = difference(no_recycling_map, input_storage)
else: else:
no_recycling_map = [ no_recycling_map = [
storage_map[r] for r in self.no_recycling if r not in fgraph.inputs storage_map[r] for r in self.no_recycling if r not in fgraph.inputs
...@@ -2003,7 +2004,7 @@ class _Linker(LocalLinker): ...@@ -2003,7 +2004,7 @@ class _Linker(LocalLinker):
) )
try: try:
thunk_py() thunk_py()
except (utils.MethodNotDefined, NotImplementedError): except (MethodNotDefined, NotImplementedError):
# shouldn't have put it into the list in # shouldn't have put it into the list in
# the first place # the first place
thunk_py = None thunk_py = None
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import theano import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.utils import flatten from theano.utils import flatten
_logger = logging.getLogger("theano.gof.compiledir") _logger = logging.getLogger("theano.gof.compiledir")
......
...@@ -26,8 +26,9 @@ from theano.gof import graph ...@@ -26,8 +26,9 @@ from theano.gof import graph
from theano.gof.fg import InconsistencyError from theano.gof.fg import InconsistencyError
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.toolbox import Feature, NodeFinder from theano.gof.toolbox import Feature, NodeFinder
from theano.gof.utils import AssocList, flatten from theano.gof.utils import AssocList
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
from theano.utils import flatten
_logger = logging.getLogger("theano.gof.opt") _logger = logging.getLogger("theano.gof.opt")
......
...@@ -344,70 +344,6 @@ class AssocList: ...@@ -344,70 +344,6 @@ class AssocList:
return f"AssocList({self._dict}, {self._list})" return f"AssocList({self._dict}, {self._list})"
def memoize(f):
    """
    Cache the return value of `f` for each distinct set of arguments.

    Parameters
    ----------
    f : callable
        Function to wrap.  Every positional and keyword argument value must
        be hashable, because the values are used to build the cache key.

    Returns
    -------
    callable
        A wrapper around `f` (carrying its name and docstring) that only
        calls `f` the first time a given argument combination is seen.  The
        cache is unbounded and lives as long as the wrapper does.
    """
    from functools import wraps

    cache = {}

    @wraps(f)
    def rval(*args, **kwargs):
        # Keyword names are unique strings, so sorting the items never has
        # to compare the values; it makes f(a=1, b=2) and f(b=2, a=1) share
        # a single cache entry instead of being computed twice.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]

    return rval
def uniq(seq):
    """
    Return the items of `seq` with duplicates removed, keeping first-seen order.

    The result is a list, never a set: the item kept for each group of equal
    values is always the first occurrence, so two inputs with the same
    pattern of duplication yield results in the same order.

    Parameters
    ----------
    seq : iterable
        Items to de-duplicate.  Any iterable is accepted, including one-shot
        iterators.  Hashable items are tracked in a set; unhashable items
        fall back to an equality scan.

    Returns
    -------
    list
        The first occurrence of every distinct item, in input order.
    """
    kept = []
    seen = set()  # hashable items already kept (constant-time lookup)
    unhashable = []  # kept items that cannot be stored in `seen`
    for item in seq:
        try:
            duplicate = item in seen
        except TypeError:
            # `item` is unhashable: compare it against everything kept so far.
            hashable = False
            duplicate = item in kept
        else:
            hashable = True
            if not duplicate and unhashable:
                # A hashable item may still equal an unhashable one.
                duplicate = item in unhashable
        if duplicate:
            continue
        kept.append(item)
        if hashable:
            seen.add(item)
        else:
            unhashable.append(item)
    return kept
def difference(seq1, seq2):
    r"""
    Return the elements of `seq1` that are not in `seq2`, i.e. ``seq1\seq2``.

    Parameters
    ----------
    seq1 : iterable
        Candidate elements; their order (and any duplicates) is preserved.
    seq2 : iterable
        Elements to exclude.  Membership is decided with ``in``.

    Returns
    -------
    list
        The elements of `seq1`, in order, that do not occur in `seq2`.
    """
    if not hasattr(seq2, "__len__"):
        # A one-shot iterable would be consumed by the first membership test
        # and silently stop excluding anything afterwards.
        seq2 = list(seq2)

    # Below this size a linear scan is assumed to be as cheap as building a
    # set (threshold guessed by the original author -JB).
    if len(seq2) >= 4:
        try:
            set2 = set(seq2)
            return [x for x in seq1 if x not in set2]
        except TypeError:
            # Some element of seq1 or seq2 is unhashable:
            # use the O(len(seq1) * len(seq2)) scan below instead.
            pass
    return [x for x in seq1 if x not in seq2]
def to_return_values(values):
    """
    Unwrap a one-element sequence.

    Returns ``values[0]`` when `values` holds exactly one item; otherwise
    `values` itself is returned unchanged.
    """
    return values[0] if len(values) == 1 else values
def from_return_values(values):
    """
    Normalise a return value to a sequence.

    A list or tuple is handed back as-is (same object); anything else is
    wrapped in a new one-element list.
    """
    if not isinstance(values, (list, tuple)):
        return [values]
    return values
def toposort(prereqs_d): def toposort(prereqs_d):
""" """
Sorts prereqs_d.keys() topologically. Sorts prereqs_d.keys() topologically.
...@@ -447,17 +383,3 @@ def toposort(prereqs_d): ...@@ -447,17 +383,3 @@ def toposort(prereqs_d):
"some orderings contain invalid elements." "some orderings contain invalid elements."
) )
return seq return seq
def flatten(a):
    """
    Recursively flatten nested tuples, lists and sets into one flat list.

    Any value that is not a tuple, list or set is treated as a leaf, so a
    bare leaf is returned wrapped in a one-element list.
    """
    if not isinstance(a, (tuple, list, set)):
        return [a]

    flat = []
    for element in a:
        flat += flatten(element)
    return flat
import typing import typing
from copy import copy, deepcopy from copy import copy, deepcopy
from theano import utils
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply from theano.gof.graph import Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.utils import to_return_values
from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
from theano.utils import deprecated, difference, to_return_values
class Container: class Container:
...@@ -191,7 +190,7 @@ class Linker: ...@@ -191,7 +190,7 @@ class Linker:
f"make_thunk method of {type(self)} is not implemented." f"make_thunk method of {type(self)} is not implemented."
) )
@utils.deprecated("Marked for deletion. Only tests use it.") @deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single=True, **kwargs): def make_function(self, unpack_single=True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
...@@ -387,7 +386,7 @@ class PerformLinker(LocalLinker): ...@@ -387,7 +386,7 @@ class PerformLinker(LocalLinker):
# True seems like some special code for *everything*?? -JB # True seems like some special code for *everything*?? -JB
# FunctionMaker always passes a list I think -JB # FunctionMaker always passes a list I think -JB
no_recycling = list(storage_map.values()) no_recycling = list(storage_map.values())
no_recycling = utils.difference(no_recycling, input_storage) no_recycling = difference(no_recycling, input_storage)
else: else:
no_recycling = [ no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs storage_map[r] for r in no_recycling if r not in fgraph.inputs
......
...@@ -16,7 +16,7 @@ from theano.gof.callcache import CallCache ...@@ -16,7 +16,7 @@ from theano.gof.callcache import CallCache
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
from theano.gof.graph import Constant, NoParams, io_toposort from theano.gof.graph import Constant, NoParams, io_toposort
from theano.gof.graph import variables as get_variables from theano.gof.graph import variables as get_variables
from theano.gof.utils import MethodNotDefined, difference, uniq from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, Linker, LocalLinker, PerformLinker from theano.link.basic import Container, Linker, LocalLinker, PerformLinker
from theano.link.c.cmodule import ( from theano.link.c.cmodule import (
METH_VARARGS, METH_VARARGS,
...@@ -27,6 +27,7 @@ from theano.link.c.cmodule import ( ...@@ -27,6 +27,7 @@ from theano.link.c.cmodule import (
) )
from theano.link.c.cmodule import get_module_cache as _get_module_cache from theano.link.c.cmodule import get_module_cache as _get_module_cache
from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
from theano.utils import difference, uniq
_logger = logging.getLogger("theano.link.c.basic") _logger = logging.getLogger("theano.link.c.basic")
......
...@@ -26,10 +26,10 @@ from theano.configdefaults import config, gcc_version_str ...@@ -26,10 +26,10 @@ from theano.configdefaults import config, gcc_version_str
# we will abuse the lockfile mechanism when reading and writing the registry # we will abuse the lockfile mechanism when reading and writing the registry
from theano.gof import compilelock from theano.gof import compilelock
from theano.gof.utils import flatten
from theano.link.c.exceptions import MissingGXX from theano.link.c.exceptions import MissingGXX
from theano.utils import ( from theano.utils import (
LOCAL_BITWIDTH, LOCAL_BITWIDTH,
flatten,
hash_from_code, hash_from_code,
output_subprocess_Popen, output_subprocess_Popen,
subprocess_Popen, subprocess_Popen,
......
from collections.abc import Sequence from collections.abc import Sequence
from warnings import warn from warnings import warn
from theano.gof import utils
from theano.gof.graph import Constant from theano.gof.graph import Constant
from theano.link.basic import Container, PerformLinker from theano.link.basic import Container, PerformLinker
from theano.link.utils import gc_helper, map_storage, streamline from theano.link.utils import gc_helper, map_storage, streamline
from theano.utils import difference
class JAXLinker(PerformLinker): class JAXLinker(PerformLinker):
...@@ -167,7 +167,7 @@ class JAXLinker(PerformLinker): ...@@ -167,7 +167,7 @@ class JAXLinker(PerformLinker):
if no_recycling is True: if no_recycling is True:
no_recycling = list(storage_map.values()) no_recycling = list(storage_map.values())
no_recycling = utils.difference(no_recycling, input_storage) no_recycling = difference(no_recycling, input_storage)
else: else:
no_recycling = [ no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs storage_map[r] for r in no_recycling if r not in fgraph.inputs
......
...@@ -27,16 +27,11 @@ from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes ...@@ -27,16 +27,11 @@ from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.opt import MergeOptimizer from theano.gof.opt import MergeOptimizer
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.utils import ( from theano.gof.utils import MetaObject, MethodNotDefined
MetaObject,
MethodNotDefined,
difference,
from_return_values,
to_return_values,
)
from theano.gradient import DisconnectedType, grad_undefined from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
from theano.utils import difference, from_return_values, to_return_values
builtin_bool = bool builtin_bool = bool
......
...@@ -160,13 +160,14 @@ from theano.gof.opt import ( ...@@ -160,13 +160,14 @@ from theano.gof.opt import (
) )
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.toolbox import ReplaceValidate from theano.gof.toolbox import ReplaceValidate
from theano.gof.utils import MethodNotDefined, TestValueError, memoize from theano.gof.utils import MethodNotDefined, TestValueError
from theano.printing import FunctionPrinter, debugprint, pprint from theano.printing import FunctionPrinter, debugprint, pprint
from theano.scalar import bool as bool_t from theano.scalar import bool as bool_t
from theano.tensor import basic as tt from theano.tensor import basic as tt
from theano.tensor.blas_headers import blas_header_text, blas_header_version from theano.tensor.blas_headers import blas_header_text, blas_header_version
from theano.tensor.opt import in2out, local_dimshuffle_lift from theano.tensor.opt import in2out, local_dimshuffle_lift
from theano.tensor.type import values_eq_approx_remove_inf_nan from theano.tensor.type import values_eq_approx_remove_inf_nan
from theano.utils import memoize
_logger = logging.getLogger("theano.tensor.blas") _logger = logging.getLogger("theano.tensor.blas")
......
...@@ -3,7 +3,7 @@ from copy import copy ...@@ -3,7 +3,7 @@ from copy import copy
import numpy as np import numpy as np
import theano import theano
from theano import gof, scalar from theano import scalar
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import ParamsType from theano.gof import ParamsType
from theano.gof.graph import Apply from theano.gof.graph import Apply
...@@ -15,6 +15,7 @@ from theano.misc.safe_asarray import _asarray ...@@ -15,6 +15,7 @@ from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
from theano.scalar import get_scalar_type from theano.scalar import get_scalar_type
from theano.tensor import elemwise_cgen as cgen from theano.tensor import elemwise_cgen as cgen
from theano.utils import uniq
_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] _numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
...@@ -899,12 +900,12 @@ second dimension ...@@ -899,12 +900,12 @@ second dimension
_inames = inames _inames = inames
_onames = onames _onames = onames
inames = gof.utils.uniq(inames) inames = uniq(inames)
inputs = gof.utils.uniq(node.inputs) inputs = uniq(node.inputs)
# assert that inames and inputs order stay consistent. # assert that inames and inputs order stay consistent.
# This is to protect against future changes of uniq. # This is to protect against future changes of uniq.
assert len(inames) == len(inputs) assert len(inames) == len(inputs)
ii, iii = list(zip(*gof.utils.uniq(list(zip(_inames, node.inputs))))) ii, iii = list(zip(*uniq(list(zip(_inames, node.inputs)))))
assert all([x == y for x, y in zip(ii, inames)]) assert all([x == y for x, y in zip(ii, inames)])
assert all([x == y for x, y in zip(iii, inputs)]) assert all([x == y for x, y in zip(iii, inputs)])
......
...@@ -296,3 +296,81 @@ def hash_from_code(msg): ...@@ -296,3 +296,81 @@ def hash_from_code(msg):
# Python 3 does not like module names that start with # Python 3 does not like module names that start with
# a digit. # a digit.
return "m" + hashlib.sha256(msg).hexdigest() return "m" + hashlib.sha256(msg).hexdigest()
def memoize(f):
    """
    Cache the return value of `f` for each distinct set of arguments.

    Parameters
    ----------
    f : callable
        Function to wrap.  Every positional and keyword argument value must
        be hashable, because the values are used to build the cache key.

    Returns
    -------
    callable
        A wrapper around `f` (carrying its name and docstring) that only
        calls `f` the first time a given argument combination is seen.  The
        cache is unbounded and lives as long as the wrapper does.
    """
    from functools import wraps

    cache = {}

    @wraps(f)
    def rval(*args, **kwargs):
        # Keyword names are unique strings, so sorting the items never has
        # to compare the values; it makes f(a=1, b=2) and f(b=2, a=1) share
        # a single cache entry instead of being computed twice.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]

    return rval
def uniq(seq):
    """
    Return the items of `seq` with duplicates removed, keeping first-seen order.

    The result is a list, never a set: the item kept for each group of equal
    values is always the first occurrence, so two inputs with the same
    pattern of duplication yield results in the same order.

    Parameters
    ----------
    seq : iterable
        Items to de-duplicate.  Any iterable is accepted, including one-shot
        iterators.  Hashable items are tracked in a set; unhashable items
        fall back to an equality scan.

    Returns
    -------
    list
        The first occurrence of every distinct item, in input order.
    """
    kept = []
    seen = set()  # hashable items already kept (constant-time lookup)
    unhashable = []  # kept items that cannot be stored in `seen`
    for item in seq:
        try:
            duplicate = item in seen
        except TypeError:
            # `item` is unhashable: compare it against everything kept so far.
            hashable = False
            duplicate = item in kept
        else:
            hashable = True
            if not duplicate and unhashable:
                # A hashable item may still equal an unhashable one.
                duplicate = item in unhashable
        if duplicate:
            continue
        kept.append(item)
        if hashable:
            seen.add(item)
        else:
            unhashable.append(item)
    return kept
def difference(seq1, seq2):
    r"""
    Return the elements of `seq1` that are not in `seq2`, i.e. ``seq1\seq2``.

    Parameters
    ----------
    seq1 : iterable
        Candidate elements; their order (and any duplicates) is preserved.
    seq2 : iterable
        Elements to exclude.  Membership is decided with ``in``.

    Returns
    -------
    list
        The elements of `seq1`, in order, that do not occur in `seq2`.
    """
    if not hasattr(seq2, "__len__"):
        # A one-shot iterable would be consumed by the first membership test
        # and silently stop excluding anything afterwards.
        seq2 = list(seq2)

    # Below this size a linear scan is assumed to be as cheap as building a
    # set (threshold guessed by the original author -JB).
    if len(seq2) >= 4:
        try:
            set2 = set(seq2)
            return [x for x in seq1 if x not in set2]
        except TypeError:
            # Some element of seq1 or seq2 is unhashable:
            # use the O(len(seq1) * len(seq2)) scan below instead.
            pass
    return [x for x in seq1 if x not in seq2]
def to_return_values(values):
    """
    Unwrap a one-element sequence.

    Returns ``values[0]`` when `values` holds exactly one item; otherwise
    `values` itself is returned unchanged.
    """
    return values[0] if len(values) == 1 else values
def from_return_values(values):
    """
    Normalise a return value to a sequence.

    A list or tuple is handed back as-is (same object); anything else is
    wrapped in a new one-element list.
    """
    if not isinstance(values, (list, tuple)):
        return [values]
    return values
def flatten(a):
    """
    Recursively flatten nested tuples, lists and sets into one flat list.

    Any value that is not a tuple, list or set is treated as a leaf, so a
    bare leaf is returned wrapped in a one-element list.
    """
    if not isinstance(a, (tuple, list, set)):
        return [a]

    flat = []
    for element in a:
        flat += flatten(element)
    return flat
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论