提交 caa8e40d authored 作者: James Bergstra's avatar James Bergstra

new feature: Updates

上级 adc96a96
......@@ -74,6 +74,8 @@ from printing import \
import scan_module
from scan_module import scan, map, reduce, foldl, foldr, clone
from updates import Updates
import tensor
import scalar
#import sparse #we don't import by default as we don't want to force having scipy installed.
......
......@@ -53,6 +53,7 @@ from scan_op import safe_new, safe_to_cpu
import scan_utils
from scan_utils import safe_new, safe_to_cpu, traverse
from theano.sandbox import cuda
from theano.updates import Updates
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan')
......@@ -917,7 +918,7 @@ def scan( fn
### and so on ...
##
update_map = {}
update_map = Updates()
def remove_dimensions( outs, steps_return, offsets = None):
out_ls = []
for idx, out in enumerate(outs):
......
import theano
from theano.updates import Updates
import theano.tensor as T
def test_updates_setitem():
ok = True
up = Updates()
sv = theano.shared('asdf')
# keys have to be SharedVariables
try:
up[5] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables
try:
up[T.vector()] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables
up[theano.shared(88)] = 7
def test_updates_add():
up1 = Updates()
up2 = Updates()
a = theano.shared('a')
b = theano.shared('b')
assert not up1 + up2
up1[a] = 5
# test that addition works
assert up1
assert up1 + up2
assert not up2
assert len(up1+up2)==1
assert (up1 + up2)[a] == 5
up2[b] = 7
assert up1
assert up1 + up2
assert up2
assert len(up1+up2)==2
assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
assert a in (up1 + up2)
assert b in (up1 + up2)
# this works even though there is a collision
# because values all match
assert len(up1 + up1 + up1)==1
up2[a] = 8 # a gets different value in up1 and up2
try:
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right?
up2[a] = 10
"""Defines Updates object for storing a (SharedVariable, new_value) mapping.
"""
__authors__ = "theano-dev"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
from theano.compile.sharedvalue import SharedVariable
import logging
logger = logging.getLogger('theano.updates')
class Updates(dict):
"""
Dict-like mapping from SharedVariable keys to their new values.
This mapping supports the use of the "+" operator for the union of updates.
"""
def __setitem__(self, key, value):
if isinstance(key, SharedVariable):
#TODO: consider doing error-checking on value.
# insist that it is a Theano variable? Have the right type?
# This could have weird consequences - for example a
# GPU SharedVariable is customarily associated with a TensorType
# value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately?
return super(Updates, self).__setitem__(key, value)
else:
raise TypeError('Updates keys must inherit from SharedVariable', key)
def update(self, other):
for key, val in dict(other).iteritems():
if key in self:
if self[key] == val:
continue
raise KeyError('Collision', key)
self[key] = val # __setitem__ does type-checking
def __add__(self, other):
rval = Updates()
rval.update(self)
rval.update(other)
return rval
def __radd__(other, self):
rval = Updates()
rval.update(other)
rval.update(self)
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论