提交 2d6f8566 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add support for non-scalar random distributions, by adding a new keyword

argument "ndim_added" to random_function. Use that mechanism for "multinomial" and "permutation" distributions.
上级 f39f3d69
...@@ -64,7 +64,13 @@ class RandomFunction(gof.Op): ...@@ -64,7 +64,13 @@ class RandomFunction(gof.Op):
:param args: a list of default arguments for the function :param args: a list of default arguments for the function
:param kwargs: if the 'inplace' key is there, its value will be used to determine if the op operates inplace or not :param kwargs:
If the 'inplace' key is there, its value will be used to
determine if the op operates inplace or not.
If the 'ndim_added' key is there, its value indicates how
many more dimensions this op will add to the output, in
addition to the shape's dimensions (used in multinomial and
permutation).
""" """
self.__setstate__([fn, outtype, args, kwargs]) self.__setstate__([fn, outtype, args, kwargs])
...@@ -73,11 +79,13 @@ class RandomFunction(gof.Op): ...@@ -73,11 +79,13 @@ class RandomFunction(gof.Op):
and self.fn == other.fn\ and self.fn == other.fn\
and self.outtype == other.outtype\ and self.outtype == other.outtype\
and self.args == other.args\ and self.args == other.args\
and self.inplace == other.inplace and self.inplace == other.inplace\
and self.ndim_added == other.ndim_added
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.fn) \ return hash(type(self)) ^ hash(self.fn) \
^ hash(self.outtype) ^ hash(self.args) ^ hash(self.inplace) ^ hash(self.outtype) ^ hash(self.args)\
^ hash(self.inplace) ^ hash(self.ndim_added)
def __getstate__(self): def __getstate__(self):
return self.state return self.state
...@@ -96,15 +104,19 @@ class RandomFunction(gof.Op): ...@@ -96,15 +104,19 @@ class RandomFunction(gof.Op):
self.inplace = kwargs.pop('inplace', False) self.inplace = kwargs.pop('inplace', False)
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.ndim_added = kwargs.pop('ndim_added', 0)
def make_node(self, r, shape, *args): def make_node(self, r, shape, *args):
""" """
:param r: a numpy.RandomState instance, or a Variable of Type RandomStateType that will :param r: a numpy.RandomState instance, or a Variable of Type RandomStateType that will
contain a RandomState instance. contain a RandomState instance.
:param shape: an lvector with the shape of the tensor output by this Op. At runtime, :param shape: an lvector with a shape defining how many samples to draw.
the value associated with this lvector must have a length that matches the number of In the case of scalar distributions, it is the shape of the tensor output by this Op.
dimensions promised by `self.outtype`. In that case, at runtime, the value associated with this lvector must have a length
equal to the number of dimensions promised by `self.outtype`.
In general, the number of output dimenstions is equal to
len(self.outtype)+self.ndim_added.
:param args: the values associated with these variables will be passed to the RandomState :param args: the values associated with these variables will be passed to the RandomState
function during perform as extra "*args"-style arguments. These should be castable to function during perform as extra "*args"-style arguments. These should be castable to
...@@ -167,7 +179,7 @@ class RandomFunction(gof.Op): ...@@ -167,7 +179,7 @@ class RandomFunction(gof.Op):
r, shape, args = inputs[0], inputs[1], inputs[2:] r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState assert type(r) == numpy.random.RandomState
r_orig = r r_orig = r
assert self.outtype.ndim == len(shape) assert self.outtype.ndim == len(shape) + self.ndim_added
if not self.inplace: if not self.inplace:
r = copy(r) r = copy(r)
rout[0] = r rout[0] = r
...@@ -204,6 +216,10 @@ def random_function(fn, dtype, *rfargs, **rfkwargs): ...@@ -204,6 +216,10 @@ def random_function(fn, dtype, *rfargs, **rfkwargs):
the user can give an integer as first argument, which will be interpreted as the the user can give an integer as first argument, which will be interpreted as the
number of dimensions. number of dimensions.
If the distribution is not scalar (e.g., a multinomial), the output will have
more dimensions than what the shape argument suggests. The "ndim_added" keyword
arguments allows to specify how many dimensions to add (for a multinomial, 1).
The number of dimensions for the following shape arguments can be inferred: The number of dimensions for the following shape arguments can be inferred:
- shape(x) - shape(x)
- make_lvector(x, y, z, ...) - make_lvector(x, y, z, ...)
...@@ -224,6 +240,8 @@ def random_function(fn, dtype, *rfargs, **rfkwargs): ...@@ -224,6 +240,8 @@ def random_function(fn, dtype, *rfargs, **rfkwargs):
if ndim is None: if ndim is None:
raise ValueError('Cannot infer the number of dimensions from the shape argument.') raise ValueError('Cannot infer the number of dimensions from the shape argument.')
# note: rf could be cached for future use # note: rf could be cached for future use
ndim_added = rfkwargs.get('ndim_added', 0)
ndim += ndim_added
rf = RandomFunction(fn, tensor.TensorType(dtype = dtype, broadcastable = (False,)*ndim), *rfargs, **rfkwargs) rf = RandomFunction(fn, tensor.TensorType(dtype = dtype, broadcastable = (False,)*ndim), *rfargs, **rfkwargs)
return rf(r, shape, *args, **kwargs) return rf(r, shape, *args, **kwargs)
return f return f
...@@ -296,16 +314,30 @@ def permutation_helper(random_state, n, shape): ...@@ -296,16 +314,30 @@ def permutation_helper(random_state, n, shape):
out[i] = random_state.permutation(n) out[i] = random_state.permutation(n)
return out return out
permutation = random_function(permutation_helper, 'int64', 1) permutation = random_function(permutation_helper, 'int64', 1, ndim_added=1)
permutation.__doc__ = """ permutation.__doc__ = """
Usage: permutation(random_state, size, n) Usage: permutation(random_state, size, n)
Returns a permutation of the integers between 0 and n-1. Returns permutations of the integers between 0 and n-1, as many times
as required by size. For instance, if size=(p,q), p*q permutations
will be generated, and the output shape will be (p,q,n), because each
permutation is of size n.
If the size argument is ambiguous on the number of dimensions, the first
argument may be a plain integer i, which should correspond to len(size).
Note that the output will then be of dimension i+1.
""" """
multinomial = random_function('multinomial', 'float64', 1, [0.5, 0.5]) multinomial = random_function('multinomial', 'float64', 1, [0.5, 0.5], ndim_added=1)
multinomial.__doc__ = """ multinomial.__doc__ = """
Usage: multinomial(random_state, size, n, pvals) Usage: multinomial(random_state, size, n, pvals)
Sample from a multinomial distribution defined by probabilities pvals.
Sample from a multinomial distribution defined by probabilities pvals,
as many times as required by size. For instance, if size=(p,q), p*q
samples will be drawn, and the output shape will be (p,q,len(pvals)).
If the size argument is ambiguous on the number of dimensions, the first
argument may be a plain integer i, which should correspond to len(size).
Note that the output will then be of dimension i+1.
""" """
...@@ -313,7 +345,8 @@ Sample from a multinomial distribution defined by probabilities pvals. ...@@ -313,7 +345,8 @@ Sample from a multinomial distribution defined by probabilities pvals.
def random_make_inplace(node): def random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, RandomFunction) and not op.inplace: if isinstance(op, RandomFunction) and not op.inplace:
return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs opkwargs = dict(inplace=True, ndim_added=op.ndim_added)
return RandomFunction(op.fn, op.outtype, *op.args, **opkwargs).make_node(*node.inputs).outputs
return False return False
optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace') optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论