提交 b1ff1b33 authored 作者: Frederic's avatar Frederic

The Updates object now check that the key are SharedVariable when we pass them…

The Updates object now check that the key are SharedVariable when we pass them in the __init__ function.
上级 c3f6b9fb
import unittest
import theano import theano
from theano.updates import Updates from theano.updates import Updates
import theano.tensor as T import theano.tensor as T
def test_updates_setitem(): class test_ifelse(unittest.TestCase):
ok = True
def test_updates_init(self):
self.assertRaises(TypeError, Updates, dict(d=3))
up = Updates()
sv = theano.shared('asdf') sv = theano.shared('asdf')
Updates({sv:3})
# keys have to be SharedVariables def test_updates_setitem(self):
try:
up[5] = 7
ok = False
except TypeError:
ok = True ok = True
assert ok
# keys have to be SharedVariables up = Updates()
try: sv = theano.shared('asdf')
up[T.vector()] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables # keys have to be SharedVariables
up[theano.shared(88)] = 7 self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
up[theano.shared(88)] = 7
def test_updates_add(): def test_updates_add(self):
up1 = Updates() up1 = Updates()
up2 = Updates() up2 = Updates()
...@@ -37,7 +33,6 @@ def test_updates_add(): ...@@ -37,7 +33,6 @@ def test_updates_add():
a = theano.shared('a') a = theano.shared('a')
b = theano.shared('b') b = theano.shared('b')
assert not up1 + up2 assert not up1 + up2
up1[a] = 5 up1[a] = 5
...@@ -47,7 +42,7 @@ def test_updates_add(): ...@@ -47,7 +42,7 @@ def test_updates_add():
assert up1 + up2 assert up1 + up2
assert not up2 assert not up2
assert len(up1+up2)==1 assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
up2[b] = 7 up2[b] = 7
...@@ -55,7 +50,7 @@ def test_updates_add(): ...@@ -55,7 +50,7 @@ def test_updates_add():
assert up1 + up2 assert up1 + up2
assert up2 assert up2
assert len(up1+up2)==2 assert len(up1 + up2) == 2
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7 assert (up1 + up2)[b] == 7
...@@ -64,7 +59,7 @@ def test_updates_add(): ...@@ -64,7 +59,7 @@ def test_updates_add():
# this works even though there is a collision # this works even though there is a collision
# because values all match # because values all match
assert len(up1 + up1 + up1)==1 assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2 up2[a] = 8 # a gets different value in up1 and up2
try: try:
...@@ -75,5 +70,3 @@ def test_updates_add(): ...@@ -75,5 +70,3 @@ def test_updates_add():
# reassigning to a key works fine right? # reassigning to a key works fine right?
up2[a] = 10 up2[a] = 10
...@@ -19,6 +19,15 @@ class Updates(dict): ...@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, SharedVariable): if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论