提交 b1ff1b33 authored 作者: Frederic's avatar Frederic

The Updates object now check that the key are SharedVariable when we pass them…

The Updates object now check that the key are SharedVariable when we pass them in the __init__ function.
上级 c3f6b9fb
import unittest
import theano
from theano.updates import Updates
import theano.tensor as T
def test_updates_setitem():
ok = True
class test_ifelse(unittest.TestCase):
def test_updates_init(self):
self.assertRaises(TypeError, Updates, dict(d=3))
up = Updates()
sv = theano.shared('asdf')
Updates({sv:3})
# keys have to be SharedVariables
try:
up[5] = 7
ok = False
except TypeError:
def test_updates_setitem(self):
ok = True
assert ok
# keys have to be SharedVariables
try:
up[T.vector()] = 7
ok = False
except TypeError:
ok = True
assert ok
up = Updates()
sv = theano.shared('asdf')
# keys have to be SharedVariables
up[theano.shared(88)] = 7
self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
up[theano.shared(88)] = 7
def test_updates_add():
def test_updates_add(self):
up1 = Updates()
up2 = Updates()
......@@ -37,7 +33,6 @@ def test_updates_add():
a = theano.shared('a')
b = theano.shared('b')
assert not up1 + up2
up1[a] = 5
......@@ -47,7 +42,7 @@ def test_updates_add():
assert up1 + up2
assert not up2
assert len(up1+up2)==1
assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5
up2[b] = 7
......@@ -55,7 +50,7 @@ def test_updates_add():
assert up1 + up2
assert up2
assert len(up1+up2)==2
assert len(up1 + up2) == 2
assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
......@@ -64,7 +59,7 @@ def test_updates_add():
# this works even though there is a collision
# because values all match
assert len(up1 + up1 + up1)==1
assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2
try:
......@@ -75,5 +70,3 @@ def test_updates_add():
# reassigning to a key works fine right?
up2[a] = 10
......@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates.
"""
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value):
if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论