提交 e73f233e authored 作者: hexahedria's avatar hexahedria

Use shift modulo axis size to handle large shifts in theano.tensor.roll

上级 81e52783
...@@ -4047,6 +4047,11 @@ def roll(x, shift, axis=None): ...@@ -4047,6 +4047,11 @@ def roll(x, shift, axis=None):
else: else:
axis = 0 axis = 0
# Shift may be larger than the size of the axis. If so, since the
# roll operation is cyclic, we can take the shift modulo the size
# of the axis
shift = shift % x.shape[axis]
# A slice of all elements in a dimension ':' # A slice of all elements in a dimension ':'
allslice = slice(None) allslice = slice(None)
# List of slices describing the front half [:, :, shift:, :] # List of slices describing the front half [:, :, shift:, :]
......
...@@ -3736,6 +3736,22 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3736,6 +3736,22 @@ class T_Join_and_Split(unittest.TestCase):
assert (out == want).all() assert (out == want).all()
# Test rolling on axis 0 with a positive shift that is
# larger than axis size
want = numpy.roll(a.get_value(borrow=True), 4, 0)
b = roll(a, get_shift(4), 0)
out = theano.function([], b)()
assert (out == want).all()
# Test rolling on axis 0 with a negative shift that is
# larger than axis size
want = numpy.roll(a.get_value(borrow=True), -4, 0)
b = roll(a, get_shift(-4), 0)
out = theano.function([], b)()
assert (out == want).all()
def test_stack_vector(self): def test_stack_vector(self):
a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX)) a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX))
b = as_tensor_variable(numpy.array([7, 8, 9], dtype=self.floatX)) b = as_tensor_variable(numpy.array([7, 8, 9], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论