提交 a3ae249e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add shuffle_rows function to RandomStreams, inspired from numpy.random.shuffle

上级 584754da
...@@ -6,7 +6,7 @@ import numpy ...@@ -6,7 +6,7 @@ import numpy
from theano.compile import module, In, Component from theano.compile import module, In, Component
from theano.gof import Container from theano.gof import Container
from theano.tensor import raw_random from theano.tensor import raw_random, reorder_row_elements
class RandomStreamsInstance(object): class RandomStreamsInstance(object):
"""RandomStreamsInstance""" """RandomStreamsInstance"""
...@@ -189,3 +189,10 @@ class RandomStreams(Component): ...@@ -189,3 +189,10 @@ class RandomStreams(Component):
""" """
return self.gen(raw_random.multinomial, *args, **kwargs) return self.gen(raw_random.multinomial, *args, **kwargs)
def shuffle_rows(self, input):
"""Return a variable with every row (rightmost index) shuffled"""
perm = self.permutation(input.ndim-1, input.shape[:-1], input.shape[-1])
shuffled = reorder_row_elements(input, perm)
return shuffled
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论