提交 86a30e71 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #61 from nouiz/curand_fix

Curand fix
...@@ -24,7 +24,7 @@ class CURAND_Base(theano.gof.Op): ...@@ -24,7 +24,7 @@ class CURAND_Base(theano.gof.Op):
This Op uses a generic-typed shared variable to point to a CObject that This Op uses a generic-typed shared variable to point to a CObject that
encapsulates this opaque reference. encapsulates this opaque reference.
Each random variable is created with a generator of None. Each random variable is created with a generator of False.
The actual random number generator is allocated from the seed, on the first The actual random number generator is allocated from the seed, on the first
call to allocate random numbers (see c_code). call to allocate random numbers (see c_code).
...@@ -64,6 +64,9 @@ class CURAND_Base(theano.gof.Op): ...@@ -64,6 +64,9 @@ class CURAND_Base(theano.gof.Op):
def __hash__(self): def __hash__(self):
return hash((type(self), self._config())) return hash((type(self), self._config()))
def __str__(self):
return self.__class__.__name__+"{inplace=%s, out_dtype=%s}"%(
self.destructive, self.output_type)
def make_node(self, generator, size): def make_node(self, generator, size):
return theano.gof.Apply(self, [generator, size], return theano.gof.Apply(self, [generator, size],
...@@ -186,7 +189,7 @@ class CURAND_Base(theano.gof.Op): ...@@ -186,7 +189,7 @@ class CURAND_Base(theano.gof.Op):
%(fail)s; %(fail)s;
} }
%(o_generator)s = PyCObject_FromVoidPtr(gen, &free_generator); %(o_generator)s = PyCObject_FromVoidPtr(gen, &free_generator);
assert (%(i_generator)s == Py_None); assert (%(i_generator)s == Py_False);
} }
else if (%(destructive)s) else if (%(destructive)s)
{ {
...@@ -216,7 +219,7 @@ class CURAND_Base(theano.gof.Op): ...@@ -216,7 +219,7 @@ class CURAND_Base(theano.gof.Op):
""" %locals() """ %locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
class CURAND_Normal(CURAND_Base): class CURAND_Normal(CURAND_Base):
...@@ -285,7 +288,7 @@ class CURAND_RandomStreams(object): ...@@ -285,7 +288,7 @@ class CURAND_RandomStreams(object):
else: else:
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg assert isinstance(size, Variable) and size.ndim==1, msg
generator = theano.shared(None) #makes a generic generator = theano.shared(False) #makes a generic
s_size = theano.tensor.as_tensor_variable(size) s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size, u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed()) self.next_seed())
...@@ -309,7 +312,7 @@ class CURAND_RandomStreams(object): ...@@ -309,7 +312,7 @@ class CURAND_RandomStreams(object):
else: else:
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg assert isinstance(size, Variable) and size.ndim==1, msg
generator = theano.shared(None) #makes a generic generator = theano.shared(False) #makes a generic
s_size = theano.tensor.as_tensor_variable(size) s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size, u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed()) self.next_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论