提交 384b9e66 authored 作者: Tanjay94's avatar Tanjay94

Added swapaxes function.

上级 d82eb54a
...@@ -5026,3 +5026,12 @@ def ptp(a, axis=None): ...@@ -5026,3 +5026,12 @@ def ptp(a, axis=None):
def power(x, y): def power(x, y):
return x**y return x**y
def swapaxes(y,axis1,axis2):
"swap axes of inputted tensor"
y = as_tensor_variable(y)
ndim = y.ndim
li = range(0,ndim)
li[axis1], li[axis2] = li[axis2], li[axis1]
return y.dimshuffle(li)
...@@ -6898,6 +6898,22 @@ if __name__ == '__main__': ...@@ -6898,6 +6898,22 @@ if __name__ == '__main__':
t.setUp() t.setUp()
t.test_infer_shape() t.test_infer_shape()
class T_swapaxesbadinput(unittest.TestCase):
def test_no_dimensional_input(self):
self.assertRaises(IndexError, Axes.swapaxes, 2,0,1)
def test_unidimensional_input(self):
self.assertRaises(IndexError, Axes.swapaxes, [2,1],0,1)
def test_not_enough_dimension(self):
self.assertRaises(IndexError, Axes.swapaxes, [[2,1],[3,4]], 3, 4)
def test_doubleswap(self):
y = matrix()
n = Axes.swapaxes(y,0,1)
f = function([y], n)
testMatrix = [[2,1],[3,4]]
self.assertTrue(numpy.array_equal(testMatrix,f(f(testMatrix))))
class T_Power(): class T_Power():
def test_numpy_compare(self): def test_numpy_compare(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论