提交 bcd856b4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4798 from hexahedria/roll_modulo

Use shift modulo axis size to handle large shifts in theano.tensor.roll
......@@ -4047,6 +4047,11 @@ def roll(x, shift, axis=None):
else:
axis = 0
# Shift may be larger than the size of the axis. If so, since the
# roll operation is cyclic, we can take the shift modulo the size
# of the axis
shift = shift % x.shape[axis]
# A slice of all elements in a dimension ':'
allslice = slice(None)
# List of slices describing the front half [:, :, shift:, :]
......
......@@ -3736,6 +3736,22 @@ class T_Join_and_Split(unittest.TestCase):
assert (out == want).all()
# Test rolling on axis 0 with a positive shift that is
# larger than axis size
want = numpy.roll(a.get_value(borrow=True), 4, 0)
b = roll(a, get_shift(4), 0)
out = theano.function([], b)()
assert (out == want).all()
# Test rolling on axis 0 with a negative shift that is
# larger than axis size
want = numpy.roll(a.get_value(borrow=True), -4, 0)
b = roll(a, get_shift(-4), 0)
out = theano.function([], b)()
assert (out == want).all()
def test_stack_vector(self):
a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX))
b = as_tensor_variable(numpy.array([7, 8, 9], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论