提交 22ef48cb authored 作者: David Warde-Farley's avatar David Warde-Farley

Add support/test for NumPy-style default axis.

Elegantly solved with a recursive call.
上级 7320ff5e
...@@ -4291,7 +4291,7 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join), ...@@ -4291,7 +4291,7 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join),
printing.FunctionPrinter('join')) printing.FunctionPrinter('join'))
def roll(x, shift, axis=0): def roll(x, shift, axis=None):
""" """
Convenience function to roll `TensorType`s along the given axis. Convenience function to roll `TensorType`s along the given axis.
Syntax copies numpy.roll function Syntax copies numpy.roll function
...@@ -4304,7 +4304,11 @@ def roll(x, shift, axis=0): ...@@ -4304,7 +4304,11 @@ def roll(x, shift, axis=0):
- axis : int (symbolic or literal) (optional) - axis : int (symbolic or literal) (optional)
The axis along which elements are shifted. The axis along which elements are shifted.
Defaults to zero (deviation from numpy behavior) Defaults to zero (deviation from numpy behavior)
Defaults to flattening first, rolling, and then reshaping.
""" """
if axis is None:
y = x.flatten()
return roll(y, shift, axis=0).reshape(x.shape)
# A slice of all elements in a dimension ':' # A slice of all elements in a dimension ':'
allslice = slice(None) allslice = slice(None)
......
...@@ -2841,6 +2841,12 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2841,6 +2841,12 @@ class T_Join_and_Split(unittest.TestCase):
assert (out == want).all() assert (out == want).all()
# Test rolling on default axis with ndim > 1
want = numpy.roll(numpy.arange(21).reshape((3, 7)), 2)
b = roll(a, 2)
out = theano.function([], b)()
assert (out == want).all()
def test_stack_vector(self): def test_stack_vector(self):
a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX)) a = self.shared(numpy.array([1, 2, 3], dtype=self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论