提交 b410eac0 authored 作者: nouiz's avatar nouiz

Merge pull request #223 from dwf/fix_roll_default_axis

Fix default axis for tensor.roll
......@@ -4290,20 +4290,34 @@ def join(axis, *tensors):
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join),
printing.FunctionPrinter('join'))
def roll(x, shift, axis=0):
def roll(x, shift, axis=None):
"""
Convenience function to roll `TensorType`s along the given axis.
Syntax copies numpy.roll function
Parameters
----------
- x : a Tensor
- shift : int (symbolic or literal)
The number of places by which elements are shifted
- axis : int (symbolic or literal) (optional)
The axis along which elements are shifted.
Defaults to zero (deviation from numpy behavior)
x : tensor_like
Input tensor.
shift : int (symbolic or literal)
The number of places by which elements are shifted.
axis : int (symbolic or literal) (optional)
The axis along which elements are shifted. By default, the array
is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : tensor
Output tensor, with the same shape as `x`.
"""
if axis is None:
if x.ndim > 1:
y = x.flatten()
return roll(y, shift, axis=0).reshape(x.shape)
else:
axis = 0
# A slice of all elements in a dimension ':'
allslice = slice(None)
......@@ -4313,11 +4327,13 @@ def roll(x, shift, axis=0):
[allslice] * (x.ndim - axis - 1))
# List of slices describing the back half [:, :, :shift, :]
end_slice = slice(0, -shift)
end_list = [allslice]*axis + [end_slice] + [allslice]*(x.ndim-axis-1)
end_list = ([allslice] * axis + [end_slice] +
[allslice] * (x.ndim - axis - 1))
return join(axis,
Subtensor(front_list)(x),
Subtensor(end_list)(x))
@constructor
def shape_padleft(t, n_ones=1):
"""Reshape `t` by left-padding the shape with `n_ones` 1s
......
......@@ -2823,24 +2823,36 @@ class T_Join_and_Split(unittest.TestCase):
assert (out == want).all()
# Test simple 1D example with explicit 0 axis
b = roll(a, -1, 0)
want = numpy.array([2, 3, 4, 5, 6, 1])
out = theano.function([], b)()
assert (out == want).all()
# Test 2D example - ensure that behavior matches numpy.roll behavior
a = self.shared(numpy.arange(21).reshape((3, 7)))
b = roll(a, -2, 1)
want = numpy.arange(21).reshape((3, 7))
want = numpy.roll(want, -2, 1)
want = numpy.roll(a.get_value(borrow=True), -2, 1)
out = theano.function([], b)()
assert (out == want).all()
# Test rolling on axis 0
want = numpy.arange(21).reshape((3, 7))
want = numpy.roll(want, -2, 0)
want = numpy.roll(a.get_value(borrow=True), -2, 0)
b = roll(a, -2, 0)
out = theano.function([], b)()
assert (out == want).all()
# Test rolling on default axis with ndim > 1
want = numpy.roll(a.get_value(borrow=True), 2)
b = roll(a, 2)
out = theano.function([], b)()
assert (out == want).all()
def test_stack_vector(self):
a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论