提交 d5e160a6 authored 作者: Tanjay94's avatar Tanjay94

Ajusted the code format and added an interface to swapaxes in the TensorVariable object.

上级 f7daa82c
...@@ -5032,6 +5032,6 @@ def swapaxes(y,axis1,axis2): ...@@ -5032,6 +5032,6 @@ def swapaxes(y,axis1,axis2):
"swap axes of inputted tensor" "swap axes of inputted tensor"
y = as_tensor_variable(y) y = as_tensor_variable(y)
ndim = y.ndim ndim = y.ndim
li = range(0,ndim) li = range(0, ndim)
li[axis1], li[axis2] = li[axis2], li[axis1] li[axis1], li[axis2] = li[axis2], li[axis1]
return y.dimshuffle(li) return y.dimshuffle(li)
...@@ -6901,21 +6901,21 @@ if __name__ == '__main__': ...@@ -6901,21 +6901,21 @@ if __name__ == '__main__':
class T_swapaxes(unittest.TestCase): class T_swapaxes(unittest.TestCase):
def test_no_dimensional_input(self): def test_no_dimensional_input(self):
self.assertRaises(IndexError, swapaxes, 2,0,1) self.assertRaises(IndexError, swapaxes, 2, 0, 1)
def test_unidimensional_input(self): def test_unidimensional_input(self):
self.assertRaises(IndexError, swapaxes, [2,1],0,1) self.assertRaises(IndexError, swapaxes, [2, 1], 0, 1)
def test_not_enough_dimension(self): def test_not_enough_dimension(self):
self.assertRaises(IndexError, swapaxes, [[2,1],[3,4]], 3, 4) self.assertRaises(IndexError, swapaxes, [[2, 1], [3, 4]], 3, 4)
def test_doubleswap(self): def test_doubleswap(self):
y = matrix() y = matrix()
n = swapaxes(y,0,1) n = swapaxes(y, 0, 1)
f = function([y], n) f = function([y], n)
testMatrix = [[2,1],[3,4]] testMatrix = [[2, 1], [3, 4]]
self.assertTrue(numpy.array_equal(testMatrix,f(f(testMatrix)))) self.assertTrue(numpy.array_equal(testMatrix, f(f(testMatrix))))
class T_Power(): class T_Power():
def test_numpy_compare(self): def test_numpy_compare(self):
......
...@@ -699,4 +699,11 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -699,4 +699,11 @@ class TensorConstant(_tensor_py_operators, Constant):
copy.deepcopy(self.data, memo), copy.deepcopy(self.data, memo),
copy.deepcopy(self.name, memo)) copy.deepcopy(self.name, memo))
def swapaxes(self, axis1, axis2)
"""Return 'tensor.swapaxes(self, axis1, axis2)
If a matrix is provided with the right axes, its transpose will be returned.
"""
return theano.tensor.basic.swapaxes(self, axis1, axis2)
TensorType.Constant = TensorConstant TensorType.Constant = TensorConstant
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论