Fixed typo which was causing an error (undefined variable). Please test before

committing ! :(
上级 0fe44b63
...@@ -865,7 +865,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -865,7 +865,6 @@ class _Linker(gof.link.LocalLinker):
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0])
#if r in r_vals:
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set, _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=True) clobber_dr_vals=True)
......
...@@ -648,7 +648,7 @@ class First(BinaryScalarOp): ...@@ -648,7 +648,7 @@ class First(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz if x.type in grad_type else None, None return gz if x.type in grad_types else None, None
first = First(transfer_type(0), name = 'first') first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
...@@ -668,7 +668,7 @@ class Identity(UnaryScalarOp): ...@@ -668,7 +668,7 @@ class Identity(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz if x.type in grad_type else None, return gz if x.type in grad_types else None,
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
......
...@@ -172,3 +172,15 @@ class RandomStreams(Component): ...@@ -172,3 +172,15 @@ class RandomStreams(Component):
""" """
return self.gen(raw_random.random_integers, *args, **kwargs) return self.gen(raw_random.random_integers, *args, **kwargs)
randstream_singleton = []
def getRandomStream(seed=None, force_new=False):
global randstream_singleton
if force_new or not randstream_singleton:
print 'creating random stream with seed %i' % seed
randstream_singleton = []
randstream_singleton.append(RandomStreams(seed))
elif seed:
print >> sys.stderr, 'Warning: RandomStream singleton already instantiated. Seed %i will be ignored'%seed
return randstream_singleton[0]
...@@ -3,10 +3,10 @@ __docformat__ = "restructuredtext en" ...@@ -3,10 +3,10 @@ __docformat__ = "restructuredtext en"
import sys import sys
import unittest import unittest
import numpy import numpy
from theano.tensor.randomstreams import RandomStreams, raw_random, getRandomStream, randstream_singleton
from theano.tensor.randomstreams import RandomStreams, raw_random
from theano.compile import Module, Method, Member from theano.compile import Module, Method, Member
import theano
from theano import tensor from theano import tensor
from theano import compile, gof from theano import compile, gof
...@@ -139,6 +139,40 @@ class T_RandomStreams(unittest.TestCase): ...@@ -139,6 +139,40 @@ class T_RandomStreams(unittest.TestCase):
def test_singleton(self):
moda = Module()
moda.randa = getRandomStream(12)
a = moda.randa.uniform((2,2))
moda.fn = Method([], a)
imoda = moda.make()
imoda.randa.initialize()
modb = Module()
modb.randb = getRandomStream()
b = modb.randb.uniform((2,2))
modb.fn = Method([], b)
imodb = modb.make()
imodb.randb.initialize()
avals1 = imoda.fn()
bvals1 = imodb.fn()
modc = Module()
modc.randc = getRandomStream(12, force_new=True)
a2 = modc.randc.uniform((2,2))
b2 = modc.randc.uniform((2,2))
modc.fna = Method([], a2)
modc.fnb = Method([], b2)
imodc = modc.make()
imodc.randc.initialize()
avals2 = imodc.fna()
bvals2 = imodc.fnb()
assert (avals1 == avals2).all()
assert (bvals1 == bvals2).all()
if __name__ == '__main__': if __name__ == '__main__':
from theano.tests import main from theano.tests import main
main("test_randomstreams") main("test_randomstreams")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论