提交 d24c03a0 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

got rid of the whitelisting, it was making optimizations fail

上级 a702ad7f
......@@ -960,7 +960,7 @@ class DestroyHandler(toolbox.Bookkeeper):
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list)
tolerated = OrderedSet((idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx))
if idx0 == destroyed_idx), known_deterministc=True)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
......
......@@ -7,9 +7,17 @@ except ImportError:
from theano.gof.python25 import OrderedDict
import types
def check_deterministic(iterable):
if not isinstance(iterable, (list, tuple, OrderedSet, types.GeneratorType)):
raise TypeError(type(iterable))
def check_deterministic(iterable, known_deterministic):
# Most places where OrderedSet is used, theano interprets any exception
# whatsoever as a problem that an optimization introduced into the graph.
# If I raise a TypeError when the DestoryHandler tries to do something
# non-deterministic, it will just result in optimizations getting ignored.
# So I must use an assert here. In the long term we should fix the rest of
# theano to use exceptions correctly, so that this can be a TypeError.
if iterable is not None:
assert isinstance(iterable, (list, tuple, OrderedSet, types.GeneratorType))
if isinstance(iterable, types.GeneratorType):
assert known_deterministic
if MutableSet is not None:
# From http://code.activestate.com/recipes/576694/
......@@ -39,13 +47,13 @@ if MutableSet is not None:
# Added by IG-- pre-existing theano code expected sets
# to have this method
def update(self, iterable):
check_deterministic(iterable)
def update(self, iterable, known_deterministic = False):
check_deterministic(iterable, known_deterministic)
self |= iterable
def __init__(self, iterable=None):
def __init__(self, iterable=None, known_deterministic=False):
# Checks added by IG
check_deterministic(iterable)
check_deterministic(iterable, known_deterministic)
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # key --> [key, prev, next]
......@@ -118,7 +126,7 @@ else:
if iterable is not None:
self.update(iterable)
def update(self, container):
def update(self, container, known_deterministic=False):
check_deterministic(container)
for elem in container:
self.add(elem)
......
......@@ -72,3 +72,6 @@ def test_determinism_1():
# (Sometimes you sample the same outcome twice in a row)
for i in xrange(10):
run(1, log)
if __name__ == '__main__':
test_determinism_1()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论