提交 cf655c6f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

make determinism type checks work with generator expressions

上级 8015719a
......@@ -958,8 +958,9 @@ class DestroyHandler(toolbox.Bookkeeper):
#CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
assert isinstance(tolerate_same, list)
tolerated = OrderedSet((idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx), known_deterministic=True)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
......
......@@ -5,6 +5,12 @@ except ImportError:
# Python 2.4
pass
from theano.gof.python25 import OrderedDict
import types
def check_deterministic(iterable, known_deterministic):
if not isinstance(iterable, (list, tuple, OrderedSet)) or \
(isinstance(iterable, types.GeneratorType and known_deterministic)):
raise TypeError((type(iterable), known_deterministic))
if MutableSet is not None:
# From http://code.activestate.com/recipes/576694/
......@@ -14,13 +20,18 @@ if MutableSet is not None:
# Added by IG-- pre-existing theano code expected sets
# to have this method
def update(self, container):
# only allowed ordered containers
assert isinstance(container, (list, OrderedSet))
for elem in container:
self.add(elem)
def update(self, iterable, known_deterministic=False):
check_deterministic(iterable, known_deterministic)
self |= iterable
def __init__(self, iterable=None):
def __init__(self, iterable=None, known_deterministic=False):
"""
known_deterministic: if iterable is a generator expression,
the caller must certify that it is from
a deterministic class, like list or OrderedSet
"""
# Checks added by IG
check_deterministic(iterable, known_deterministic)
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # key --> [key, prev, next]
......@@ -88,16 +99,13 @@ else:
An implementation of OrderedSet based on the keys of
an OrderedDict.
"""
def __init__(self, iterable=None):
def __init__(self, iterable=None, known_deterministic=False):
self.data = OrderedDict()
if iterable is not None:
self.update(iterable)
self.update(iterable, known_deterministic)
def update(self, container):
# only allowed ordered containers
if not isinstance(container, (list, tuple, OrderedSet)):
raise TypeError("OrderedSet can only be ordered if updated "
" with ordered containers. Got "+str(type(container)))
def update(self, container, known_deterministic=False):
check_deterministic(container, known_deterministic)
for elem in container:
self.add(elem)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论