提交 4e64abe0 authored 作者: James Bergstra's avatar James Bergstra

added a transpose method for TensorVariables and a transpose test

上级 44f47751
...@@ -1236,7 +1236,31 @@ class _tensor_py_operators: ...@@ -1236,7 +1236,31 @@ class _tensor_py_operators:
#TRANSPOSE #TRANSPOSE
T = property(lambda self: transpose(self)) T = property(lambda self: transpose(self))
def transpose(self, *axes):
"""
Return `tensor.transpose(self, axes)`
or `tensor.transpose(self, axes[0])`
If only one `axes` argument is provided and it is iterable, then it is
assumed to be the entire axes tuple, and passed intact to
tensor.transpose.
"""
if len(axes) == 0:
return transpose(self)
try:
iter(axes[0])
iterable = True
except TypeError:
iterable = False
if len(axes) == 1 and iterable:
return transpose(self, axes[0])
else:
return transpose(self, axes)
shape = property(lambda self: shape(self)) shape = property(lambda self: shape(self))
size = property(lambda self: prod(self.shape)) size = property(lambda self: prod(self.shape))
# We can't implement __len__ to provide a better error message. # We can't implement __len__ to provide a better error message.
...@@ -1346,18 +1370,23 @@ class _tensor_py_operators: ...@@ -1346,18 +1370,23 @@ class _tensor_py_operators:
# CONVENIENT ACCESS TO TYPE PROPERTIES # CONVENIENT ACCESS TO TYPE PROPERTIES
ndim = property(lambda self: self.type.ndim) ndim = property(lambda self: self.type.ndim)
"""The rank of this tensor.""" """The rank of this tensor."""
broadcastable = property(lambda self: self.type.broadcastable) broadcastable = property(lambda self: self.type.broadcastable)
"""The broadcastable signature of this tensor. """The broadcastable signature of this tensor.
See :doc:`broadcasting` for details. See :doc:`broadcasting` for details.
""" """
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
""" The dtype of this tensor. """ """ The dtype of this tensor. """
#extra pseudo-operator symbols #extra pseudo-operator symbols
def __dot__(left, right): return dot(left, right) def __dot__(left, right):
def __rdot__(right, left): return dot(left, right) return dot(left, right)
def __rdot__(right, left):
return dot(left, right)
def sum(self, axis=None): def sum(self, axis=None):
return elemwise.Sum(axis)(self) return elemwise.Sum(axis)(self)
...@@ -1389,7 +1418,6 @@ class _tensor_py_operators: ...@@ -1389,7 +1418,6 @@ class _tensor_py_operators:
#TO TRUMP NUMPY OPERATORS #TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000 __array_priority__ = 1000
def get_constant_value(self): def get_constant_value(self):
return get_constant_value(self) return get_constant_value(self)
...@@ -2933,10 +2961,17 @@ def get_canonical_form_slice(theslice, length): ...@@ -2933,10 +2961,17 @@ def get_canonical_form_slice(theslice, length):
return value, 1 return value, 1
def transpose(x, **kwargs): def transpose(x, axes=None):
dims = range(x.ndim-1, -1, -1) """
Reorder the dimensions of x. (Default: reverse them)
This is a macro around dimshuffle that matches the numpy.transpose
function.
return DimShuffle(x.broadcastable, dims, inplace=False)(x) """
if axes is None:
axes = range(x.ndim-1, -1, -1)
return DimShuffle(x.broadcastable, axes, inplace=False)(x)
class AdvancedIndexingError(TypeError): class AdvancedIndexingError(TypeError):
......
...@@ -5083,6 +5083,7 @@ def test_mod(): ...@@ -5083,6 +5083,7 @@ def test_mod():
): ):
assert fn(a,b) == a%b, (a,) assert fn(a,b) == a%b, (a,)
def test_mod_compile(): def test_mod_compile():
""" """
This test generate an Elemwise of Composite as: This test generate an Elemwise of Composite as:
...@@ -5106,6 +5107,7 @@ def test_mod_compile(): ...@@ -5106,6 +5107,7 @@ def test_mod_compile():
f = theano.function([x,y],out) f = theano.function([x,y],out)
def test_unalign(): def test_unalign():
if config.floatX == 'float64': if config.floatX == 'float64':
dtype="b1,f8" dtype="b1,f8"
...@@ -5137,6 +5139,7 @@ def test_unalign(): ...@@ -5137,6 +5139,7 @@ def test_unalign():
if not should_raise: if not should_raise:
raise Exception("Theano raised an exception when none was expected") raise Exception("Theano raised an exception when none was expected")
def test_dimshuffle_duplicate(): def test_dimshuffle_duplicate():
x = tensor.vector() x = tensor.vector()
...@@ -5151,7 +5154,6 @@ def test_dimshuffle_duplicate(): ...@@ -5151,7 +5154,6 @@ def test_dimshuffle_duplicate():
assert success assert success
class T_get_constant_value(unittest.TestCase): class T_get_constant_value(unittest.TestCase):
def test_get_constant_value(self): def test_get_constant_value(self):
a = tensor.stack(1,2,3) a = tensor.stack(1,2,3)
assert get_constant_value(a[0])==1 assert get_constant_value(a[0])==1
...@@ -5186,6 +5188,7 @@ class T_get_constant_value(unittest.TestCase): ...@@ -5186,6 +5188,7 @@ class T_get_constant_value(unittest.TestCase):
for j in range(c.value.shape[1]): for j in range(c.value.shape[1]):
assert get_constant_value(c[i,j]) == c.value[i,j] assert get_constant_value(c[i,j]) == c.value[i,j]
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
We test that ticket #649 stay fixed. We test that ticket #649 stay fixed.
...@@ -5215,7 +5218,6 @@ class test_complex_mod(unittest.TestCase): ...@@ -5215,7 +5218,6 @@ class test_complex_mod(unittest.TestCase):
class test_size(unittest.TestCase): class test_size(unittest.TestCase):
""" """
Ensure the `size` attribute of tensors behaves as in numpy. Ensure the `size` attribute of tensors behaves as in numpy.
""" """
...@@ -5243,7 +5245,6 @@ class test_size(unittest.TestCase): ...@@ -5243,7 +5245,6 @@ class test_size(unittest.TestCase):
class test_numpy_assumptions(unittest.TestCase): class test_numpy_assumptions(unittest.TestCase):
""" """
Verify that some assumptions Theano makes on Numpy's behavior still hold. Verify that some assumptions Theano makes on Numpy's behavior still hold.
""" """
...@@ -5274,6 +5275,43 @@ class test_numpy_assumptions(unittest.TestCase): ...@@ -5274,6 +5275,43 @@ class test_numpy_assumptions(unittest.TestCase):
assert (dtype1 == dtype2) == (str(dtype1) == str(dtype2)) assert (dtype1 == dtype2) == (str(dtype1) == str(dtype2))
def test_transpose():
x1 = tensor.dvector()
x2 = tensor.dmatrix()
x3 = tensor.dtensor3()
x1v = numpy.arange(24)
x2v = numpy.arange(24).reshape(2, 12)
x3v = numpy.arange(24).reshape(2, 3, 4)
f = theano.function([x1, x2, x3], [
tensor.transpose(x1),
tensor.transpose(x2),
tensor.transpose(x3),
x1.transpose(),
x2.transpose(),
x3.transpose(),
x2.transpose(0, 1),
x3.transpose((0, 2, 1)),
])
t1, t2, t3, t1b, t2b, t3b, t2c, t3c = f(x1v, x2v, x3v)
assert t1.shape == numpy.transpose(x1v).shape
assert t2.shape == numpy.transpose(x2v).shape
assert t3.shape == numpy.transpose(x3v).shape
assert numpy.all(t1 == numpy.transpose(x1v))
assert numpy.all(t2 == numpy.transpose(x2v))
assert numpy.all(t3 == numpy.transpose(x3v))
assert numpy.all(t1b == x1v.transpose())
assert numpy.all(t2b == x2v.transpose())
assert numpy.all(t3b == x3v.transpose())
assert t2c.shape == (2, 12)
assert t3c.shape == (2, 4, 3)
assert numpy.all(t2c == x2v.transpose([0, 1]))
assert numpy.all(t3c == x3v.transpose([0, 2, 1]))
if __name__ == '__main__': if __name__ == '__main__':
if 0: if 0:
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论