提交 725c5664 authored 作者: Frederic's avatar Frederic

Set a name to transposed name when the input have one.

上级 2f223eff
...@@ -3508,7 +3508,10 @@ def transpose(x, axes=None): ...@@ -3508,7 +3508,10 @@ def transpose(x, axes=None):
""" """
if axes is None: if axes is None:
axes = range((x.ndim - 1), -1, -1) axes = range((x.ndim - 1), -1, -1)
return DimShuffle(x.broadcastable, axes, inplace=False)(x) ret = DimShuffle(x.broadcastable, axes, inplace=False)(x)
if x.name and axes == range((x.ndim - 1), -1, -1):
ret.name = x.name + '.T'
return ret
class AdvancedIndexingError(TypeError): class AdvancedIndexingError(TypeError):
......
...@@ -5841,9 +5841,9 @@ class test_numpy_assumptions(unittest.TestCase): ...@@ -5841,9 +5841,9 @@ class test_numpy_assumptions(unittest.TestCase):
def test_transpose(): def test_transpose():
x1 = tensor.dvector() x1 = tensor.dvector('x1')
x2 = tensor.dmatrix() x2 = tensor.dmatrix('x2')
x3 = tensor.dtensor3() x3 = tensor.dtensor3('x3')
x1v = numpy.arange(24) x1v = numpy.arange(24)
x2v = numpy.arange(24).reshape(2, 12) x2v = numpy.arange(24).reshape(2, 12)
...@@ -5881,6 +5881,12 @@ def test_transpose(): ...@@ -5881,6 +5881,12 @@ def test_transpose():
assert numpy.all(t2d == numpy.transpose(x2v, [0, 1])) assert numpy.all(t2d == numpy.transpose(x2v, [0, 1]))
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1])) assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
# Check that we create a name.
assert tensor.transpose(x1).name == 'x1.T'
assert tensor.transpose(x2).name == 'x2.T'
assert tensor.transpose(x3).name == 'x3.T'
assert tensor.transpose(tensor.dmatrix()).name == None
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论