提交 86e07952 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

whitelist all generators

上级 87c1d28c
...@@ -960,7 +960,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -960,7 +960,7 @@ class DestroyHandler(toolbox.Bookkeeper):
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = OrderedSet((idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet((idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx), known_deterministic=True) if idx0 == destroyed_idx))
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
......
...@@ -7,11 +7,9 @@ except ImportError: ...@@ -7,11 +7,9 @@ except ImportError:
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
import types import types
def check_deterministic(iterable, known_deterministic): def check_deterministic(iterable):
if not isinstance(iterable, (list, tuple, OrderedSet)): if not isinstance(iterable, (list, tuple, OrderedSet, types.GeneratorType)):
if not isinstance(iterable, types.GeneratorType) and \ raise TypeError(type(iterable))
known_deterministic:
raise TypeError((type(iterable), known_deterministic))
if MutableSet is not None: if MutableSet is not None:
# From http://code.activestate.com/recipes/576694/ # From http://code.activestate.com/recipes/576694/
...@@ -41,18 +39,13 @@ if MutableSet is not None: ...@@ -41,18 +39,13 @@ if MutableSet is not None:
# Added by IG-- pre-existing theano code expected sets # Added by IG-- pre-existing theano code expected sets
# to have this method # to have this method
def update(self, iterable, known_deterministic=False): def update(self, iterable):
check_deterministic(iterable, known_deterministic) check_deterministic(iterable)
self |= iterable self |= iterable
def __init__(self, iterable=None, known_deterministic=False): def __init__(self, iterable=None):
"""
known_deterministic: if iterable is a generator expression,
the caller must certify that it is from
a deterministic class, like list or OrderedSet
"""
# Checks added by IG # Checks added by IG
check_deterministic(iterable, known_deterministic) check_deterministic(iterable)
self.end = end = [] self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # key --> [key, prev, next] self.map = {} # key --> [key, prev, next]
...@@ -120,13 +113,13 @@ else: ...@@ -120,13 +113,13 @@ else:
An implementation of OrderedSet based on the keys of An implementation of OrderedSet based on the keys of
an OrderedDict. an OrderedDict.
""" """
def __init__(self, iterable=None, known_deterministic=False): def __init__(self, iterable=None):
self.data = OrderedDict() self.data = OrderedDict()
if iterable is not None: if iterable is not None:
self.update(iterable, known_deterministic) self.update(iterable)
def update(self, container, known_deterministic=False): def update(self, container):
check_deterministic(container, known_deterministic) check_deterministic(container)
for elem in container: for elem in container:
self.add(elem) self.add(elem)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论