提交 86a30e71 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #61 from nouiz/curand_fix

Curand fix
......@@ -24,7 +24,7 @@ class CURAND_Base(theano.gof.Op):
This Op uses a generic-typed shared variable to point to a CObject that
encapsulates this opaque reference.
Each random variable is created with a generator of None.
Each random variable is created with a generator of False.
The actual random number generator is allocated from the seed, on the first
call to allocate random numbers (see c_code).
......@@ -64,6 +64,9 @@ class CURAND_Base(theano.gof.Op):
def __hash__(self):
return hash((type(self), self._config()))
def __str__(self):
return self.__class__.__name__+"{inplace=%s, out_dtype=%s}"%(
self.destructive, self.output_type)
def make_node(self, generator, size):
return theano.gof.Apply(self, [generator, size],
......@@ -186,7 +189,7 @@ class CURAND_Base(theano.gof.Op):
%(fail)s;
}
%(o_generator)s = PyCObject_FromVoidPtr(gen, &free_generator);
assert (%(i_generator)s == Py_None);
assert (%(i_generator)s == Py_False);
}
else if (%(destructive)s)
{
......@@ -216,7 +219,7 @@ class CURAND_Base(theano.gof.Op):
""" %locals()
def c_code_cache_version(self):
return (1,)
return (2,)
class CURAND_Normal(CURAND_Base):
......@@ -285,7 +288,7 @@ class CURAND_RandomStreams(object):
else:
msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg
generator = theano.shared(None) #makes a generic
generator = theano.shared(False) #makes a generic
s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed())
......@@ -309,7 +312,7 @@ class CURAND_RandomStreams(object):
else:
msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg
generator = theano.shared(None) #makes a generic
generator = theano.shared(False) #makes a generic
s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论