提交 a3ae249e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add shuffle_rows function to RandomStreams, inspired from numpy.random.shuffle

上级 584754da
......@@ -6,7 +6,7 @@ import numpy
from theano.compile import module, In, Component
from theano.gof import Container
from theano.tensor import raw_random
from theano.tensor import raw_random, reorder_row_elements
class RandomStreamsInstance(object):
"""RandomStreamsInstance"""
......@@ -189,3 +189,10 @@ class RandomStreams(Component):
"""
return self.gen(raw_random.multinomial, *args, **kwargs)
def shuffle_rows(self, input):
"""Return a variable with every row (rightmost index) shuffled"""
perm = self.permutation(input.ndim-1, input.shape[:-1], input.shape[-1])
shuffled = reorder_row_elements(input, perm)
return shuffled
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论