提交 d0e4dfc1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Rename _metadict to AssocList and move it to theano.gof.utils

上级 9bbaf30a
......@@ -21,8 +21,9 @@ import numpy as np
import theano
from theano import config
from theano.gof import graph, op, toolbox, unify, utils
from theano.gof import graph, op, toolbox, unify
from theano.gof.fg import InconsistencyError
from theano.gof.utils import AssocList, flatten
from theano.misc.ordered_set import OrderedSet
from . import destroyhandler as dh
......@@ -95,7 +96,6 @@ class GlobalOptimizer(abc.ABC):
etc.
"""
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None)
......@@ -451,80 +451,6 @@ class SeqOptimizer(GlobalOptimizer, UserList):
)
class _metadict:
"""
WRITEME
"""
# dict that accepts unhashable keys
# uses an associative list
# for internal use only
def __init__(self):
self._dict = {}
self._list = []
def __getitem__(self, item):
return self.get(item, None)
def __setitem__(self, item, value):
try:
self._dict[item] = value
except Exception:
for i, (key, val) in enumerate(self._list):
if key == item:
self._list[i] = (item, value)
return
self._list.append((item, value))
def __delitem__(self, item):
try:
if item in self._dict:
del self._dict[item]
return
except TypeError as e:
assert "unhashable type" in str(e)
for i, (key, val) in enumerate(self._list):
if key == item:
del self._list[i]
return
raise KeyError(item)
def discard(self, item):
try:
if item in self._dict:
del self._dict[item]
return
except TypeError as e:
assert "unhashable type" in str(e)
for i, (key, val) in enumerate(self._list):
if key == item:
del self._list[i]
return
def get(self, item, default):
try:
return self._dict[item]
except Exception:
for item2, value in self._list:
try:
if item == item2:
return value
if item.equals(item2):
return value
except Exception:
if item is item2:
return value
return default
def clear(self):
self._dict = {}
self._list = []
def __str__(self):
return f"({self._dict}, {self._list})"
class MergeFeature:
"""
Keeps track of variables in fgraph that cannot be merged together.
......@@ -541,9 +467,9 @@ class MergeFeature:
# For constants
self.seen_constants = set()
# variable -> signature (for constants)
self.const_sig = _metadict()
self.const_sig = AssocList()
# signature -> variable (for constants)
self.const_sig_inv = _metadict()
self.const_sig_inv = AssocList()
# For all Apply nodes
# Set of distinct (not mergeable) nodes
......@@ -906,7 +832,7 @@ class MergeOptimizer(GlobalOptimizer):
if (
sum(
[
i in utils.flatten(c.op.destroy_map.values())
i in flatten(c.op.destroy_map.values())
for c, i in clients
if c != "output" and hasattr(c.op, "destroy_map")
]
......@@ -1143,7 +1069,6 @@ class LocalOptimizer(abc.ABC):
fgraph, this is the place to do it.
"""
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class_.__name__} id={id(self)}", file=stream)
......
......@@ -284,6 +284,78 @@ class D:
self.__dict__.update(d)
class AssocList:
"""An associative list.
This class is like a `dict` that accepts unhashable keys by using an
assoc list for internal use only
"""
def __init__(self):
self._dict = {}
self._list = []
def __getitem__(self, item):
return self.get(item, None)
def __setitem__(self, item, value):
try:
self._dict[item] = value
except Exception:
for i, (key, val) in enumerate(self._list):
if key == item:
self._list[i] = (item, value)
return
self._list.append((item, value))
def __delitem__(self, item):
try:
if item in self._dict:
del self._dict[item]
return
except TypeError as e:
assert "unhashable type" in str(e)
for i, (key, val) in enumerate(self._list):
if key == item:
del self._list[i]
return
raise KeyError(item)
def discard(self, item):
try:
if item in self._dict:
del self._dict[item]
return
except TypeError as e:
assert "unhashable type" in str(e)
for i, (key, val) in enumerate(self._list):
if key == item:
del self._list[i]
return
def get(self, item, default):
try:
return self._dict[item]
except Exception:
for item2, value in self._list:
try:
if item == item2:
return value
if item.equals(item2):
return value
except Exception:
if item is item2:
return value
return default
def clear(self):
self._dict = {}
self._list = []
def __repr__(self):
return f"AssocList({self._dict}, {self._list})"
def memoize(f):
"""
Cache the return value for each tuple of arguments (which must be hashable).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论