提交 5496618e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move memoize, uniq, difference, flatten, *_return_values to theano.utils

上级 64952355
......@@ -28,10 +28,11 @@ from theano.compile.function.types import (
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard
from theano.configdefaults import config
from theano.gof import graph, ops_with_inner_function, utils
from theano.gof import graph, ops_with_inner_function
from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, LocalLinker
from theano.link.utils import map_storage, raise_with_op
from theano.utils import get_unbound_function
from theano.utils import difference, get_unbound_function
__docformat__ = "restructuredtext en"
......@@ -1827,14 +1828,14 @@ class _Linker(LocalLinker):
or debug
or not isinstance(node.op, gof.op.COp)
):
raise utils.MethodNotDefined()
raise MethodNotDefined()
node.op.prepare_node(node, storage_map, compute_map, "c")
thunk = node.op.make_c_thunk(
node, storage_map, compute_map, no_recycling
)
thunks_c.append(thunk)
except (NotImplementedError, utils.MethodNotDefined):
except (NotImplementedError, MethodNotDefined):
thunks_c.append(None)
# Pure ops don't really have a perform ( or their perform just
......@@ -1885,7 +1886,7 @@ class _Linker(LocalLinker):
# function's outputs. no_recycling_map will be used in f() below.
if self.no_recycling is True:
no_recycling_map = list(storage_map.values())
no_recycling_map = utils.difference(no_recycling_map, input_storage)
no_recycling_map = difference(no_recycling_map, input_storage)
else:
no_recycling_map = [
storage_map[r] for r in self.no_recycling if r not in fgraph.inputs
......@@ -2003,7 +2004,7 @@ class _Linker(LocalLinker):
)
try:
thunk_py()
except (utils.MethodNotDefined, NotImplementedError):
except (MethodNotDefined, NotImplementedError):
# shouldn't have put it into the list in
# the first place
thunk_py = None
......
......@@ -7,7 +7,7 @@ import numpy as np
import theano
from theano.configdefaults import config
from theano.gof.utils import flatten
from theano.utils import flatten
_logger = logging.getLogger("theano.gof.compiledir")
......
......@@ -26,8 +26,9 @@ from theano.gof import graph
from theano.gof.fg import InconsistencyError
from theano.gof.op import Op
from theano.gof.toolbox import Feature, NodeFinder
from theano.gof.utils import AssocList, flatten
from theano.gof.utils import AssocList
from theano.misc.ordered_set import OrderedSet
from theano.utils import flatten
_logger = logging.getLogger("theano.gof.opt")
......
......@@ -344,70 +344,6 @@ class AssocList:
return f"AssocList({self._dict}, {self._list})"
def memoize(f):
    """
    Cache the return value of `f` for each distinct combination of arguments.

    Parameters
    ----------
    f : callable
        The function to wrap.  Every positional and keyword argument value it
        is called with must be hashable, because the arguments are used as a
        dictionary key.

    Returns
    -------
    callable
        A wrapper that calls `f` the first time it sees a given combination
        of arguments and returns the stored result on later calls.  The cache
        is private to the wrapper and entries are never removed from it.

    Raises
    ------
    TypeError
        Raised by the wrapper when an argument value is not hashable.
    """
    from functools import wraps

    cache = {}

    # `wraps` keeps the name and docstring of `f` visible on the wrapper.
    @wraps(f)
    def rval(*args, **kwargs):
        # Sort the keyword arguments by name so that the same call written
        # with a different keyword order reuses the same cache entry.  Names
        # are unique, so the sort never has to compare the values.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]

    return rval
def uniq(seq):
    """
    Return the elements of `seq` with repeated occurrences dropped.

    An element is kept only at the position where it first occurs, so the
    result preserves the order of `seq`.  A set is deliberately not used:
    the same value must always come out at the same index, and inputs that
    share a pattern of duplication must produce results in the same order.

    `seq` must provide an ``index`` method (e.g. a list or a tuple).
    """
    # NOTE: `seq.index` rescans the sequence for every element; a set-based
    # membership test would make the condition constant time (TODO from JB).
    kept = []
    for position, item in enumerate(seq):
        if seq.index(item) == position:
            kept.append(item)
    return kept
def difference(seq1, seq2):
    r"""
    Return all elements of `seq1` which are not in `seq2`: i.e ``seq1\seq2``.

    Parameters
    ----------
    seq1 : iterable
        The elements to filter; their order is preserved in the result.
    seq2 : container
        The elements to exclude.

    Returns
    -------
    list
        The elements of `seq1` that are not members of `seq2`.

    Notes
    -----
    When `seq2` has at least four elements (a guessed threshold -JB) it is
    turned into a set so that each membership test is constant time.  If that
    is not possible because `seq2` has no length or an element is not
    hashable, the function falls back to scanning `seq2` for every element of
    `seq1`.  Only ``TypeError`` triggers the fallback; any other error is
    propagated to the caller.
    """
    try:
        # try to use O(const * len(seq1)) algo
        if len(seq2) >= 4:
            set2 = set(seq2)
            return [x for x in seq1 if x not in set2]
    except TypeError:
        # seq2 has no len(), or an element of seq1/seq2 is not hashable
        pass
    # seq2 is short or cannot be used through a set
    # -> use O(len(seq1) * len(seq2)) algo
    return [x for x in seq1 if x not in seq2]
def to_return_values(values):
    """
    Unwrap `values` when it holds exactly one element.

    Returns ``values[0]`` if ``len(values) == 1``; otherwise `values` itself
    is returned unchanged.
    """
    return values[0] if len(values) == 1 else values
def from_return_values(values):
    """
    Make sure `values` is a list or a tuple.

    A list or tuple is returned as is (the same object); any other value is
    wrapped in a new one-element list.
    """
    if not isinstance(values, (list, tuple)):
        return [values]
    return values
def toposort(prereqs_d):
"""
Sorts prereqs_d.keys() topologically.
......@@ -447,17 +383,3 @@ def toposort(prereqs_d):
"some orderings contain invalid elements."
)
return seq
def flatten(a):
    """
    Recursively flatten nested tuples, lists and sets into one flat list.

    Anything that is not a tuple, list or set is treated as a leaf and is
    returned wrapped in a one-element list.
    """
    if not isinstance(a, (tuple, list, set)):
        return [a]
    # Flatten each element in turn and concatenate the results in order.
    return [leaf for element in a for leaf in flatten(element)]
import typing
from copy import copy, deepcopy
from theano import utils
from theano.configdefaults import config
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply
from theano.gof.type import Type
from theano.gof.utils import to_return_values
from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
from theano.utils import deprecated, difference, to_return_values
class Container:
......@@ -191,7 +190,7 @@ class Linker:
f"make_thunk method of {type(self)} is not implemented."
)
@utils.deprecated("Marked for deletion. Only tests use it.")
@deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
......@@ -387,7 +386,7 @@ class PerformLinker(LocalLinker):
# True seems like some special code for *everything*?? -JB
# FunctionMaker always passes a list I think -JB
no_recycling = list(storage_map.values())
no_recycling = utils.difference(no_recycling, input_storage)
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
......
......@@ -16,7 +16,7 @@ from theano.gof.callcache import CallCache
from theano.gof.compilelock import get_lock, release_lock
from theano.gof.graph import Constant, NoParams, io_toposort
from theano.gof.graph import variables as get_variables
from theano.gof.utils import MethodNotDefined, difference, uniq
from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, Linker, LocalLinker, PerformLinker
from theano.link.c.cmodule import (
METH_VARARGS,
......@@ -27,6 +27,7 @@ from theano.link.c.cmodule import (
)
from theano.link.c.cmodule import get_module_cache as _get_module_cache
from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
from theano.utils import difference, uniq
_logger = logging.getLogger("theano.link.c.basic")
......
......@@ -26,10 +26,10 @@ from theano.configdefaults import config, gcc_version_str
# we will abuse the lockfile mechanism when reading and writing the registry
from theano.gof import compilelock
from theano.gof.utils import flatten
from theano.link.c.exceptions import MissingGXX
from theano.utils import (
LOCAL_BITWIDTH,
flatten,
hash_from_code,
output_subprocess_Popen,
subprocess_Popen,
......
from collections.abc import Sequence
from warnings import warn
from theano.gof import utils
from theano.gof.graph import Constant
from theano.link.basic import Container, PerformLinker
from theano.link.utils import gc_helper, map_storage, streamline
from theano.utils import difference
class JAXLinker(PerformLinker):
......@@ -167,7 +167,7 @@ class JAXLinker(PerformLinker):
if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = utils.difference(no_recycling, input_storage)
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
......
......@@ -27,16 +27,11 @@ from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes
from theano.gof.op import COp
from theano.gof.opt import MergeOptimizer
from theano.gof.type import Type
from theano.gof.utils import (
MetaObject,
MethodNotDefined,
difference,
from_return_values,
to_return_values,
)
from theano.gof.utils import MetaObject, MethodNotDefined
from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray
from theano.printing import pprint
from theano.utils import difference, from_return_values, to_return_values
builtin_bool = bool
......
......@@ -160,13 +160,14 @@ from theano.gof.opt import (
)
from theano.gof.params_type import ParamsType
from theano.gof.toolbox import ReplaceValidate
from theano.gof.utils import MethodNotDefined, TestValueError, memoize
from theano.gof.utils import MethodNotDefined, TestValueError
from theano.printing import FunctionPrinter, debugprint, pprint
from theano.scalar import bool as bool_t
from theano.tensor import basic as tt
from theano.tensor.blas_headers import blas_header_text, blas_header_version
from theano.tensor.opt import in2out, local_dimshuffle_lift
from theano.tensor.type import values_eq_approx_remove_inf_nan
from theano.utils import memoize
_logger = logging.getLogger("theano.tensor.blas")
......
......@@ -3,7 +3,7 @@ from copy import copy
import numpy as np
import theano
from theano import gof, scalar
from theano import scalar
from theano.configdefaults import config
from theano.gof import ParamsType
from theano.gof.graph import Apply
......@@ -15,6 +15,7 @@ from theano.misc.safe_asarray import _asarray
from theano.printing import pprint
from theano.scalar import get_scalar_type
from theano.tensor import elemwise_cgen as cgen
from theano.utils import uniq
_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
......@@ -899,12 +900,12 @@ second dimension
_inames = inames
_onames = onames
inames = gof.utils.uniq(inames)
inputs = gof.utils.uniq(node.inputs)
inames = uniq(inames)
inputs = uniq(node.inputs)
# Assert that the order of inames and inputs stays consistent.
# This protects against future changes to uniq.
assert len(inames) == len(inputs)
ii, iii = list(zip(*gof.utils.uniq(list(zip(_inames, node.inputs)))))
ii, iii = list(zip(*uniq(list(zip(_inames, node.inputs)))))
assert all([x == y for x, y in zip(ii, inames)])
assert all([x == y for x, y in zip(iii, inputs)])
......
......@@ -296,3 +296,81 @@ def hash_from_code(msg):
# Python 3 does not like module names that start with
# a digit.
return "m" + hashlib.sha256(msg).hexdigest()
def memoize(f):
    """
    Cache the return value of `f` for each distinct combination of arguments.

    Parameters
    ----------
    f : callable
        The function to wrap.  Every positional and keyword argument value it
        is called with must be hashable, because the arguments are used as a
        dictionary key.

    Returns
    -------
    callable
        A wrapper that calls `f` the first time it sees a given combination
        of arguments and returns the stored result on later calls.  The cache
        is private to the wrapper and entries are never removed from it.

    Raises
    ------
    TypeError
        Raised by the wrapper when an argument value is not hashable.
    """
    from functools import wraps

    cache = {}

    # `wraps` keeps the name and docstring of `f` visible on the wrapper.
    @wraps(f)
    def rval(*args, **kwargs):
        # Sort the keyword arguments by name so that the same call written
        # with a different keyword order reuses the same cache entry.  Names
        # are unique, so the sort never has to compare the values.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]

    return rval
def uniq(seq):
    """
    Return the elements of `seq` with repeated occurrences dropped.

    An element is kept only at the position where it first occurs, so the
    result preserves the order of `seq`.  A set is deliberately not used:
    the same value must always come out at the same index, and inputs that
    share a pattern of duplication must produce results in the same order.

    `seq` must provide an ``index`` method (e.g. a list or a tuple).
    """
    # NOTE: `seq.index` rescans the sequence for every element; a set-based
    # membership test would make the condition constant time (TODO from JB).
    kept = []
    for position, item in enumerate(seq):
        if seq.index(item) == position:
            kept.append(item)
    return kept
def difference(seq1, seq2):
    r"""
    Return all elements of `seq1` which are not in `seq2`: i.e ``seq1\seq2``.

    Parameters
    ----------
    seq1 : iterable
        The elements to filter; their order is preserved in the result.
    seq2 : container
        The elements to exclude.

    Returns
    -------
    list
        The elements of `seq1` that are not members of `seq2`.

    Notes
    -----
    When `seq2` has at least four elements (a guessed threshold -JB) it is
    turned into a set so that each membership test is constant time.  If that
    is not possible because `seq2` has no length or an element is not
    hashable, the function falls back to scanning `seq2` for every element of
    `seq1`.  Only ``TypeError`` triggers the fallback; any other error is
    propagated to the caller.
    """
    try:
        # try to use O(const * len(seq1)) algo
        if len(seq2) >= 4:
            set2 = set(seq2)
            return [x for x in seq1 if x not in set2]
    except TypeError:
        # seq2 has no len(), or an element of seq1/seq2 is not hashable
        pass
    # seq2 is short or cannot be used through a set
    # -> use O(len(seq1) * len(seq2)) algo
    return [x for x in seq1 if x not in seq2]
def to_return_values(values):
    """
    Unwrap `values` when it holds exactly one element.

    Returns ``values[0]`` if ``len(values) == 1``; otherwise `values` itself
    is returned unchanged.
    """
    return values[0] if len(values) == 1 else values
def from_return_values(values):
    """
    Make sure `values` is a list or a tuple.

    A list or tuple is returned as is (the same object); any other value is
    wrapped in a new one-element list.
    """
    if not isinstance(values, (list, tuple)):
        return [values]
    return values
def flatten(a):
    """
    Recursively flatten nested tuples, lists and sets into one flat list.

    Anything that is not a tuple, list or set is treated as a leaf and is
    returned wrapped in a one-element list.
    """
    if not isinstance(a, (tuple, list, set)):
        return [a]
    # Flatten each element in turn and concatenate the results in order.
    return [leaf for element in a for leaf in flatten(element)]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论