提交 b410eac0 authored 作者: nouiz's avatar nouiz

Merge pull request #223 from dwf/fix_roll_default_axis

Fix default axis for tensor.roll
...@@ -4290,20 +4290,34 @@ def join(axis, *tensors): ...@@ -4290,20 +4290,34 @@ def join(axis, *tensors):
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join), pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join),
printing.FunctionPrinter('join')) printing.FunctionPrinter('join'))
def roll(x, shift, axis=0):
def roll(x, shift, axis=None):
""" """
Convenience function to roll `TensorType`s along the given axis. Convenience function to roll `TensorType`s along the given axis.
Syntax copies numpy.roll function Syntax copies numpy.roll function
Parameters Parameters
---------- ----------
- x : a Tensor x : tensor_like
- shift : int (symbolic or literal) Input tensor.
The number of places by which elements are shifted shift : int (symbolic or literal)
- axis : int (symbolic or literal) (optional) The number of places by which elements are shifted.
The axis along which elements are shifted. axis : int (symbolic or literal) (optional)
Defaults to zero (deviation from numpy behavior) The axis along which elements are shifted. By default, the array
is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : tensor
Output tensor, with the same shape as `x`.
""" """
if axis is None:
if x.ndim > 1:
y = x.flatten()
return roll(y, shift, axis=0).reshape(x.shape)
else:
axis = 0
# A slice of all elements in a dimension ':' # A slice of all elements in a dimension ':'
allslice = slice(None) allslice = slice(None)
...@@ -4313,11 +4327,13 @@ def roll(x, shift, axis=0): ...@@ -4313,11 +4327,13 @@ def roll(x, shift, axis=0):
[allslice] * (x.ndim - axis - 1)) [allslice] * (x.ndim - axis - 1))
# List of slices describing the back half [:, :, :shift, :] # List of slices describing the back half [:, :, :shift, :]
end_slice = slice(0, -shift) end_slice = slice(0, -shift)
end_list = [allslice]*axis + [end_slice] + [allslice]*(x.ndim-axis-1) end_list = ([allslice] * axis + [end_slice] +
[allslice] * (x.ndim - axis - 1))
return join(axis, return join(axis,
Subtensor(front_list)(x), Subtensor(front_list)(x),
Subtensor(end_list)(x)) Subtensor(end_list)(x))
@constructor @constructor
def shape_padleft(t, n_ones=1): def shape_padleft(t, n_ones=1):
"""Reshape `t` by left-padding the shape with `n_ones` 1s """Reshape `t` by left-padding the shape with `n_ones` 1s
......
...@@ -2823,24 +2823,36 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2823,24 +2823,36 @@ class T_Join_and_Split(unittest.TestCase):
assert (out == want).all() assert (out == want).all()
# Test simple 1D example with explicit 0 axis
b = roll(a, -1, 0)
want = numpy.array([2, 3, 4, 5, 6, 1])
out = theano.function([], b)()
assert (out == want).all()
# Test 2D example - ensure that behavior matches numpy.roll behavior # Test 2D example - ensure that behavior matches numpy.roll behavior
a = self.shared(numpy.arange(21).reshape((3, 7))) a = self.shared(numpy.arange(21).reshape((3, 7)))
b = roll(a, -2, 1) b = roll(a, -2, 1)
want = numpy.arange(21).reshape((3, 7)) want = numpy.roll(a.get_value(borrow=True), -2, 1)
want = numpy.roll(want, -2, 1)
out = theano.function([], b)() out = theano.function([], b)()
assert (out == want).all() assert (out == want).all()
# Test rolling on axis 0 # Test rolling on axis 0
want = numpy.arange(21).reshape((3, 7)) want = numpy.roll(a.get_value(borrow=True), -2, 0)
want = numpy.roll(want, -2, 0)
b = roll(a, -2, 0) b = roll(a, -2, 0)
out = theano.function([], b)() out = theano.function([], b)()
assert (out == want).all() assert (out == want).all()
# Test rolling on default axis with ndim > 1
want = numpy.roll(a.get_value(borrow=True), 2)
b = roll(a, 2)
out = theano.function([], b)()
assert (out == want).all()
def test_stack_vector(self): def test_stack_vector(self):
a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX)) a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论