提交 d7fb9402 authored 作者: ltoniazzi's avatar ltoniazzi 提交者: Brandon T. Willard

BUG: add tensor conversion to at.roll

上级 f4004b27
...@@ -2484,30 +2484,33 @@ def roll(x, shift, axis=None): ...@@ -2484,30 +2484,33 @@ def roll(x, shift, axis=None):
Output tensor, with the same shape as ``x``. Output tensor, with the same shape as ``x``.
""" """
_x = as_tensor_variable(x)
if axis is None: if axis is None:
if x.ndim > 1: if _x.ndim > 1:
y = x.flatten() y = _x.flatten()
return roll(y, shift, axis=0).reshape(x.shape) return roll(y, shift, axis=0).reshape(_x.shape)
else: else:
axis = 0 axis = 0
if axis < 0: if axis < 0:
axis += x.ndim axis += _x.ndim
# Shift may be larger than the size of the axis. If so, since the # Shift may be larger than the size of the axis. If so, since the
# roll operation is cyclic, we can take the shift modulo the size # roll operation is cyclic, we can take the shift modulo the size
# of the axis # of the axis
shift = shift % x.shape[axis] shift = shift % _x.shape[axis]
# A slice of all elements in a dimension ':' # A slice of all elements in a dimension ':'
allslice = slice(None) allslice = slice(None)
# List of slices describing the front half [:, :, shift:, :] # List of slices describing the front half [:, :, shift:, :]
front_slice = slice(-shift, None) front_slice = slice(-shift, None)
front_list = [allslice] * axis + [front_slice] + [allslice] * (x.ndim - axis - 1) front_list = [allslice] * axis + [front_slice] + [allslice] * (_x.ndim - axis - 1)
# List of slices describing the back half [:, :, :shift, :] # List of slices describing the back half [:, :, :shift, :]
end_slice = slice(0, -shift) end_slice = slice(0, -shift)
end_list = [allslice] * axis + [end_slice] + [allslice] * (x.ndim - axis - 1) end_list = [allslice] * axis + [end_slice] + [allslice] * (_x.ndim - axis - 1)
return join(axis, x.__getitem__(tuple(front_list)), x.__getitem__(tuple(end_list))) return join(
axis, _x.__getitem__(tuple(front_list)), _x.__getitem__(tuple(end_list))
)
def stack(*tensors, **kwargs): def stack(*tensors, **kwargs):
......
...@@ -1449,6 +1449,15 @@ class TestJoinAndSplit: ...@@ -1449,6 +1449,15 @@ class TestJoinAndSplit:
assert (out == want).all() assert (out == want).all()
# Pass a list to make sure `a` is converted to a
# TensorVariable by roll
a = [1, 2, 3, 4, 5, 6]
b = roll(a, get_shift(2))
want = np.array([5, 6, 1, 2, 3, 4])
out = aesara.function([], b)()
assert (out == want).all()
def test_stack_vector(self): def test_stack_vector(self):
a = self.shared(np.array([1, 2, 3], dtype=self.floatX)) a = self.shared(np.array([1, 2, 3], dtype=self.floatX))
b = as_tensor_variable(np.array([7, 8, 9], dtype=self.floatX)) b = as_tensor_variable(np.array([7, 8, 9], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论