提交 fea13d01 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add keyword arg "dtype" to random functions.

上级 28e1bd43
...@@ -51,7 +51,7 @@ class RandomFunction(gof.Op): ...@@ -51,7 +51,7 @@ class RandomFunction(gof.Op):
""" """
def __init__(self, fn, outtype, inplace=False, ndim_added=0 ): def __init__(self, fn, outtype, inplace=False, ndim_added=0):
""" """
:param fn: a member function of numpy.RandomState :param fn: a member function of numpy.RandomState
Technically, any function with a signature like the ones in Technically, any function with a signature like the ones in
...@@ -306,7 +306,7 @@ def _generate_broadcasting_indices(out_shape, *shapes): ...@@ -306,7 +306,7 @@ def _generate_broadcasting_indices(out_shape, *shapes):
return ret_indices return ret_indices
def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None): def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype=theano.config.floatX):
""" """
Sample from a uniform distribution between low and high. Sample from a uniform distribution between low and high.
...@@ -320,10 +320,10 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None): ...@@ -320,10 +320,10 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None):
high = tensor.as_tensor_variable(high) high = tensor.as_tensor_variable(high)
ndim, size = _infer_ndim(ndim, size, low, high) ndim, size = _infer_ndim(ndim, size, low, high)
op = RandomFunction('uniform', op = RandomFunction('uniform',
tensor.TensorType(dtype = 'float64', broadcastable = (False,)*ndim) ) tensor.TensorType(dtype = dtype, broadcastable = (False,)*ndim) )
return op(random_state, size, low, high) return op(random_state, size, low, high)
def binomial(random_state, size=None, n=1, prob=0.5, ndim=None): def binomial(random_state, size=None, n=1, prob=0.5, ndim=None, dtype='int64'):
""" """
Sample n times with probability of success prob for each trial, Sample n times with probability of success prob for each trial,
return the number of successes. return the number of successes.
...@@ -338,10 +338,10 @@ def binomial(random_state, size=None, n=1, prob=0.5, ndim=None): ...@@ -338,10 +338,10 @@ def binomial(random_state, size=None, n=1, prob=0.5, ndim=None):
prob = tensor.as_tensor_variable(prob) prob = tensor.as_tensor_variable(prob)
ndim, size = _infer_ndim(ndim, size, n, prob) ndim, size = _infer_ndim(ndim, size, n, prob)
op = RandomFunction('binomial', op = RandomFunction('binomial',
tensor.TensorType(dtype = 'int64', broadcastable = (False,)*ndim) ) tensor.TensorType(dtype = dtype, broadcastable = (False,)*ndim) )
return op(random_state, size, n, prob) return op(random_state, size, n, prob)
def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None): def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype=theano.config.floatX):
""" """
Sample from a normal distribution centered on avg with Sample from a normal distribution centered on avg with
the specified standard deviation (std). the specified standard deviation (std).
...@@ -356,7 +356,7 @@ def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None): ...@@ -356,7 +356,7 @@ def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None):
std = tensor.as_tensor_variable(std) std = tensor.as_tensor_variable(std)
ndim, size = _infer_ndim(ndim, size, avg, std) ndim, size = _infer_ndim(ndim, size, avg, std)
op = RandomFunction('normal', op = RandomFunction('normal',
tensor.TensorType(dtype = 'float64', broadcastable = (False,)*ndim) ) tensor.TensorType(dtype = dtype, broadcastable = (False,)*ndim) )
return op(random_state, size, avg, std) return op(random_state, size, avg, std)
def random_integers_helper(random_state, low, high, size): def random_integers_helper(random_state, low, high, size):
...@@ -401,7 +401,7 @@ def random_integers_helper(random_state, low, high, size): ...@@ -401,7 +401,7 @@ def random_integers_helper(random_state, low, high, size):
return out return out
def random_integers(random_state, size=None, low=0, high=1, ndim=None): def random_integers(random_state, size=None, low=0, high=1, ndim=None, dtype='int64'):
""" """
Sample a random integer between low and high, both inclusive. Sample a random integer between low and high, both inclusive.
...@@ -415,7 +415,7 @@ def random_integers(random_state, size=None, low=0, high=1, ndim=None): ...@@ -415,7 +415,7 @@ def random_integers(random_state, size=None, low=0, high=1, ndim=None):
high = tensor.as_tensor_variable(high) high = tensor.as_tensor_variable(high)
ndim, size = _infer_ndim(ndim, size, low, high) ndim, size = _infer_ndim(ndim, size, low, high)
op = RandomFunction(random_integers_helper, op = RandomFunction(random_integers_helper,
tensor.TensorType(dtype = 'int64', broadcastable = (False,)*ndim) ) tensor.TensorType(dtype = dtype, broadcastable = (False,)*ndim) )
return op(random_state, size, low, high) return op(random_state, size, low, high)
def permutation_helper(random_state, n, shape): def permutation_helper(random_state, n, shape):
...@@ -448,7 +448,7 @@ def permutation_helper(random_state, n, shape): ...@@ -448,7 +448,7 @@ def permutation_helper(random_state, n, shape):
#print 'RETURNING', out.shape #print 'RETURNING', out.shape
return out return out
def permutation(random_state, size=None, n=1, ndim=None): def permutation(random_state, size=None, n=1, ndim=None, dtype='int64'):
""" """
Returns permutations of the integers between 0 and n-1, as many times Returns permutations of the integers between 0 and n-1, as many times
as required by size. For instance, if size=(p,q), p*q permutations as required by size. For instance, if size=(p,q), p*q permutations
...@@ -465,7 +465,7 @@ def permutation(random_state, size=None, n=1, ndim=None): ...@@ -465,7 +465,7 @@ def permutation(random_state, size=None, n=1, ndim=None):
ndim, size = _infer_ndim(ndim, size) ndim, size = _infer_ndim(ndim, size)
#print "NDIM", ndim, size #print "NDIM", ndim, size
op = RandomFunction(permutation_helper, op = RandomFunction(permutation_helper,
tensor.TensorType(dtype='int64', broadcastable=(False,)*(ndim+1)), tensor.TensorType(dtype=dtype, broadcastable=(False,)*(ndim+1)),
ndim_added=1) ndim_added=1)
return op(random_state, size, n) return op(random_state, size, n)
...@@ -517,7 +517,7 @@ def multinomial_helper(random_state, n, pvals, size): ...@@ -517,7 +517,7 @@ def multinomial_helper(random_state, n, pvals, size):
out[mi] = random_state.multinomial(n=n[ni], pvals=pvals[pi]) out[mi] = random_state.multinomial(n=n[ni], pvals=pvals[pi])
return out return out
def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5], ndim=None): def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5], ndim=None, dtype='int64'):
""" """
Sample n times from a multinomial distribution defined by Sample n times from a multinomial distribution defined by
probabilities pvals, as many times as required by size. For probabilities pvals, as many times as required by size. For
...@@ -554,7 +554,7 @@ optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_new ...@@ -554,7 +554,7 @@ optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_new
class RandomStreamsBase(object): class RandomStreamsBase(object):
def binomial(self, size=None, n=1, prob=0.5, ndim=None): def binomial(self, size=None, n=1, prob=0.5, ndim=None, dtype='int64'):
""" """
Sample n times with probability of success prob for each trial, Sample n times with probability of success prob for each trial,
return the number of successes. return the number of successes.
...@@ -563,9 +563,9 @@ class RandomStreamsBase(object): ...@@ -563,9 +563,9 @@ class RandomStreamsBase(object):
ndim may be a plain integer to supplement the missing ndim may be a plain integer to supplement the missing
information. information.
""" """
return self.gen(binomial, size, n, prob, ndim=ndim) return self.gen(binomial, size, n, prob, ndim=ndim, dtype=dtype)
def uniform(self, size=None, low=0.0, high=1.0, ndim=None): def uniform(self, size=None, low=0.0, high=1.0, ndim=None, dtype=theano.config.floatX):
""" """
Sample a tensor of given size whose element from a uniform Sample a tensor of given size whose element from a uniform
distribution between low and high. distribution between low and high.
...@@ -574,9 +574,9 @@ class RandomStreamsBase(object): ...@@ -574,9 +574,9 @@ class RandomStreamsBase(object):
ndim may be a plain integer to supplement the missing ndim may be a plain integer to supplement the missing
information. information.
""" """
return self.gen(uniform, size, low, high, ndim=ndim) return self.gen(uniform, size, low, high, ndim=ndim, dtype=dtype)
def normal(self, size=None, avg=0.0, std=1.0, ndim=None): def normal(self, size=None, avg=0.0, std=1.0, ndim=None, dtype=theano.config.floatX):
""" """
Sample from a normal distribution centered on avg with Sample from a normal distribution centered on avg with
the specified standard deviation (std). the specified standard deviation (std).
...@@ -585,9 +585,9 @@ class RandomStreamsBase(object): ...@@ -585,9 +585,9 @@ class RandomStreamsBase(object):
ndim may be a plain integer to supplement the missing ndim may be a plain integer to supplement the missing
information. information.
""" """
return self.gen(normal, size, avg, std, ndim=ndim) return self.gen(normal, size, avg, std, ndim=ndim, dtype=dtype)
def random_integers(self, size=None, low=0, high=1, ndim=None): def random_integers(self, size=None, low=0, high=1, ndim=None, dtype='int64'):
""" """
Sample a random integer between low and high, both inclusive. Sample a random integer between low and high, both inclusive.
...@@ -595,9 +595,9 @@ class RandomStreamsBase(object): ...@@ -595,9 +595,9 @@ class RandomStreamsBase(object):
ndim may be a plain integer to supplement the missing ndim may be a plain integer to supplement the missing
information. information.
""" """
return self.gen(random_integers, size, low, high, ndim=ndim) return self.gen(random_integers, size, low, high, ndim=ndim, dtype=dtype)
def permutation(self, size=None, n=1, ndim=None): def permutation(self, size=None, n=1, ndim=None, dtype='int64'):
""" """
Returns permutations of the integers between 0 and n-1, as many times Returns permutations of the integers between 0 and n-1, as many times
as required by size. For instance, if size=(p,q), p*q permutations as required by size. For instance, if size=(p,q), p*q permutations
...@@ -611,9 +611,9 @@ class RandomStreamsBase(object): ...@@ -611,9 +611,9 @@ class RandomStreamsBase(object):
.. note:: .. note::
Note that the output will then be of dimension ndim+1. Note that the output will then be of dimension ndim+1.
""" """
return self.gen(permutation, size, n, ndim=ndim) return self.gen(permutation, size, n, ndim=ndim, dtype=dtype)
def multinomial(self, size=None, n=1, pvals=[0.5, 0.5], ndim=None): def multinomial(self, size=None, n=1, pvals=[0.5, 0.5], ndim=None, dtype='int64'):
""" """
Sample n times from a multinomial distribution defined by Sample n times from a multinomial distribution defined by
probabilities pvals, as many times as required by size. For probabilities pvals, as many times as required by size. For
...@@ -627,7 +627,7 @@ class RandomStreamsBase(object): ...@@ -627,7 +627,7 @@ class RandomStreamsBase(object):
.. note:: .. note::
Note that the output will then be of dimension ndim+1. Note that the output will then be of dimension ndim+1.
""" """
return self.gen(multinomial, size, n, pvals, ndim=ndim) return self.gen(multinomial, size, n, pvals, ndim=ndim, dtype=dtype)
def shuffle_row_elements(self, input): def shuffle_row_elements(self, input):
"""Return a variable with every row (rightmost index) shuffled. """Return a variable with every row (rightmost index) shuffled.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论