提交 b1ff1b33 authored 作者: Frederic's avatar Frederic

The Updates object now check that the key are SharedVariable when we pass them…

The Updates object now check that the key are SharedVariable when we pass them in the __init__ function.
上级 c3f6b9fb
import unittest
import theano import theano
from theano.updates import Updates from theano.updates import Updates
import theano.tensor as T import theano.tensor as T
def test_updates_setitem(): class test_ifelse(unittest.TestCase):
ok = True
up = Updates() def test_updates_init(self):
sv = theano.shared('asdf') self.assertRaises(TypeError, Updates, dict(d=3))
# keys have to be SharedVariables sv = theano.shared('asdf')
try: Updates({sv:3})
up[5] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables def test_updates_setitem(self):
try:
up[T.vector()] = 7
ok = False
except TypeError:
ok = True ok = True
assert ok
# keys have to be SharedVariables
up[theano.shared(88)] = 7
def test_updates_add(): up = Updates()
sv = theano.shared('asdf')
up1 = Updates() # keys have to be SharedVariables
up2 = Updates() self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
a = theano.shared('a') up[theano.shared(88)] = 7
b = theano.shared('b')
def test_updates_add(self):
assert not up1 + up2 up1 = Updates()
up2 = Updates()
up1[a] = 5 a = theano.shared('a')
b = theano.shared('b')
# test that addition works assert not up1 + up2
assert up1
assert up1 + up2
assert not up2
assert len(up1+up2)==1 up1[a] = 5
assert (up1 + up2)[a] == 5
up2[b] = 7 # test that addition works
assert up1 assert up1
assert up1 + up2 assert up1 + up2
assert up2 assert not up2
assert len(up1+up2)==2 assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
assert a in (up1 + up2) up2[b] = 7
assert b in (up1 + up2) assert up1
assert up1 + up2
assert up2
# this works even though there is a collision assert len(up1 + up2) == 2
# because values all match assert (up1 + up2)[a] == 5
assert len(up1 + up1 + up1)==1 assert (up1 + up2)[b] == 7
up2[a] = 8 # a gets different value in up1 and up2 assert a in (up1 + up2)
try: assert b in (up1 + up2)
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right? # this works even though there is a collision
up2[a] = 10 # because values all match
assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2
try:
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right?
up2[a] = 10
...@@ -19,6 +19,15 @@ class Updates(dict): ...@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, SharedVariable): if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论