提交 b1ff1b33 authored 作者: Frederic's avatar Frederic

The Updates object now check that the key are SharedVariable when we pass them…

The Updates object now check that the key are SharedVariable when we pass them in the __init__ function.
上级 c3f6b9fb
import unittest
import theano
from theano.updates import Updates
import theano.tensor as T
def test_updates_setitem():
ok = True
class test_ifelse(unittest.TestCase):
up = Updates()
sv = theano.shared('asdf')
def test_updates_init(self):
self.assertRaises(TypeError, Updates, dict(d=3))
# keys have to be SharedVariables
try:
up[5] = 7
ok = False
except TypeError:
ok = True
assert ok
sv = theano.shared('asdf')
Updates({sv:3})
# keys have to be SharedVariables
try:
up[T.vector()] = 7
ok = False
except TypeError:
def test_updates_setitem(self):
ok = True
assert ok
# keys have to be SharedVariables
up[theano.shared(88)] = 7
def test_updates_add():
up = Updates()
sv = theano.shared('asdf')
up1 = Updates()
up2 = Updates()
# keys have to be SharedVariables
self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
a = theano.shared('a')
b = theano.shared('b')
up[theano.shared(88)] = 7
def test_updates_add(self):
assert not up1 + up2
up1 = Updates()
up2 = Updates()
up1[a] = 5
a = theano.shared('a')
b = theano.shared('b')
# test that addition works
assert up1
assert up1 + up2
assert not up2
assert not up1 + up2
assert len(up1+up2)==1
assert (up1 + up2)[a] == 5
up1[a] = 5
up2[b] = 7
assert up1
assert up1 + up2
assert up2
# test that addition works
assert up1
assert up1 + up2
assert not up2
assert len(up1+up2)==2
assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5
assert a in (up1 + up2)
assert b in (up1 + up2)
up2[b] = 7
assert up1
assert up1 + up2
assert up2
# this works even though there is a collision
# because values all match
assert len(up1 + up1 + up1)==1
assert len(up1 + up2) == 2
assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
up2[a] = 8 # a gets different value in up1 and up2
try:
up1 + up2
assert 0
except KeyError:
pass
assert a in (up1 + up2)
assert b in (up1 + up2)
# reassigning to a key works fine right?
up2[a] = 10
# this works even though there is a collision
# because values all match
assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2
try:
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right?
up2[a] = 10
......@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates.
"""
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value):
if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论