提交 490ef97a authored 作者: lamblin's avatar lamblin

Merge pull request #710 from nouiz/small

Small
...@@ -693,7 +693,6 @@ class T_Scan(unittest.TestCase): ...@@ -693,7 +693,6 @@ class T_Scan(unittest.TestCase):
outputs_info = [None]) outputs_info = [None])
inp = numpy.arange(5).astype('float64') inp = numpy.arange(5).astype('float64')
rval = theano.function([x], y, updates=updates)(inp) rval = theano.function([x], y, updates=updates)(inp)
import ipdb; ipdb.set_trace()
assert numpy.all(rval == inp[:-1]) assert numpy.all(rval == inp[:-1])
# simple rnn, one input, one state, weights for each; input/state are # simple rnn, one input, one state, weights for each; input/state are
......
...@@ -708,7 +708,7 @@ class ConvOp(Op): ...@@ -708,7 +708,7 @@ class ConvOp(Op):
raise NotImplementedError('todo') raise NotImplementedError('todo')
if self.dx not in (1, 2) or self.dy not in (1, 2): if self.dx not in (1, 2) or self.dy not in (1, 2):
raise Exception("ERROR: We disable ConvOp.grad now when dx or "\ raise NotImplementedError("ERROR: We disable ConvOp.grad now when dx or "\
"dy are different from 1 and 2, as there is a bug in it.") "dy are different from 1 and 2, as there is a bug in it.")
all_shape = self.imshp is not None and self.kshp is not None and \ all_shape = self.imshp is not None and self.kshp is not None and \
......
import unittest
import theano import theano
from theano.updates import Updates from theano.updates import Updates
import theano.tensor as T import theano.tensor as T
def test_updates_setitem(): class test_ifelse(unittest.TestCase):
ok = True
def test_updates_init(self):
self.assertRaises(TypeError, Updates, dict(d=3))
up = Updates()
sv = theano.shared('asdf') sv = theano.shared('asdf')
Updates({sv:3})
# keys have to be SharedVariables def test_updates_setitem(self):
try:
up[5] = 7
ok = False
except TypeError:
ok = True ok = True
assert ok
# keys have to be SharedVariables up = Updates()
try: sv = theano.shared('asdf')
up[T.vector()] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables # keys have to be SharedVariables
up[theano.shared(88)] = 7 self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
up[theano.shared(88)] = 7
def test_updates_add(): def test_updates_add(self):
up1 = Updates() up1 = Updates()
up2 = Updates() up2 = Updates()
...@@ -37,7 +33,6 @@ def test_updates_add(): ...@@ -37,7 +33,6 @@ def test_updates_add():
a = theano.shared('a') a = theano.shared('a')
b = theano.shared('b') b = theano.shared('b')
assert not up1 + up2 assert not up1 + up2
up1[a] = 5 up1[a] = 5
...@@ -47,7 +42,7 @@ def test_updates_add(): ...@@ -47,7 +42,7 @@ def test_updates_add():
assert up1 + up2 assert up1 + up2
assert not up2 assert not up2
assert len(up1+up2)==1 assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
up2[b] = 7 up2[b] = 7
...@@ -55,7 +50,7 @@ def test_updates_add(): ...@@ -55,7 +50,7 @@ def test_updates_add():
assert up1 + up2 assert up1 + up2
assert up2 assert up2
assert len(up1+up2)==2 assert len(up1 + up2) == 2
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7 assert (up1 + up2)[b] == 7
...@@ -64,7 +59,7 @@ def test_updates_add(): ...@@ -64,7 +59,7 @@ def test_updates_add():
# this works even though there is a collision # this works even though there is a collision
# because values all match # because values all match
assert len(up1 + up1 + up1)==1 assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2 up2[a] = 8 # a gets different value in up1 and up2
try: try:
...@@ -75,5 +70,3 @@ def test_updates_add(): ...@@ -75,5 +70,3 @@ def test_updates_add():
# reassigning to a key works fine right? # reassigning to a key works fine right?
up2[a] = 10 up2[a] = 10
...@@ -19,6 +19,15 @@ class Updates(dict): ...@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, SharedVariable): if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论