提交 490ef97a authored 作者: lamblin's avatar lamblin

Merge pull request #710 from nouiz/small

Small
...@@ -693,7 +693,6 @@ class T_Scan(unittest.TestCase): ...@@ -693,7 +693,6 @@ class T_Scan(unittest.TestCase):
outputs_info = [None]) outputs_info = [None])
inp = numpy.arange(5).astype('float64') inp = numpy.arange(5).astype('float64')
rval = theano.function([x], y, updates=updates)(inp) rval = theano.function([x], y, updates=updates)(inp)
import ipdb; ipdb.set_trace()
assert numpy.all(rval == inp[:-1]) assert numpy.all(rval == inp[:-1])
# simple rnn, one input, one state, weights for each; input/state are # simple rnn, one input, one state, weights for each; input/state are
......
...@@ -708,7 +708,7 @@ class ConvOp(Op): ...@@ -708,7 +708,7 @@ class ConvOp(Op):
raise NotImplementedError('todo') raise NotImplementedError('todo')
if self.dx not in (1, 2) or self.dy not in (1, 2): if self.dx not in (1, 2) or self.dy not in (1, 2):
raise Exception("ERROR: We disable ConvOp.grad now when dx or "\ raise NotImplementedError("ERROR: We disable ConvOp.grad now when dx or "\
"dy are different from 1 and 2, as there is a bug in it.") "dy are different from 1 and 2, as there is a bug in it.")
all_shape = self.imshp is not None and self.kshp is not None and \ all_shape = self.imshp is not None and self.kshp is not None and \
......
import unittest
import theano import theano
from theano.updates import Updates from theano.updates import Updates
import theano.tensor as T import theano.tensor as T
def test_updates_setitem(): class test_ifelse(unittest.TestCase):
ok = True
up = Updates() def test_updates_init(self):
sv = theano.shared('asdf') self.assertRaises(TypeError, Updates, dict(d=3))
# keys have to be SharedVariables sv = theano.shared('asdf')
try: Updates({sv:3})
up[5] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables def test_updates_setitem(self):
try:
up[T.vector()] = 7
ok = False
except TypeError:
ok = True ok = True
assert ok
# keys have to be SharedVariables
up[theano.shared(88)] = 7
def test_updates_add(): up = Updates()
sv = theano.shared('asdf')
up1 = Updates() # keys have to be SharedVariables
up2 = Updates() self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
a = theano.shared('a') up[theano.shared(88)] = 7
b = theano.shared('b')
def test_updates_add(self):
assert not up1 + up2 up1 = Updates()
up2 = Updates()
up1[a] = 5 a = theano.shared('a')
b = theano.shared('b')
# test that addition works assert not up1 + up2
assert up1
assert up1 + up2
assert not up2
assert len(up1+up2)==1 up1[a] = 5
assert (up1 + up2)[a] == 5
up2[b] = 7 # test that addition works
assert up1 assert up1
assert up1 + up2 assert up1 + up2
assert up2 assert not up2
assert len(up1+up2)==2 assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
assert a in (up1 + up2) up2[b] = 7
assert b in (up1 + up2) assert up1
assert up1 + up2
assert up2
# this works even though there is a collision assert len(up1 + up2) == 2
# because values all match assert (up1 + up2)[a] == 5
assert len(up1 + up1 + up1)==1 assert (up1 + up2)[b] == 7
up2[a] = 8 # a gets different value in up1 and up2 assert a in (up1 + up2)
try: assert b in (up1 + up2)
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right? # this works even though there is a collision
up2[a] = 10 # because values all match
assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2
try:
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right?
up2[a] = 10
...@@ -19,6 +19,15 @@ class Updates(dict): ...@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, SharedVariable): if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论