提交 6384da65 authored 作者: nouiz's avatar nouiz

Merge pull request #341 from jaberg/master

transpose method and tests
......@@ -1237,7 +1237,31 @@ class _tensor_py_operators:
#TRANSPOSE
T = property(lambda self: transpose(self))
def transpose(self, *axes):
"""
Return `tensor.transpose(self, axes)`
or `tensor.transpose(self, axes[0])`
If only one `axes` argument is provided and it is iterable, then it is
assumed to be the entire axes tuple, and passed intact to
tensor.transpose.
"""
if len(axes) == 0:
return transpose(self)
try:
iter(axes[0])
iterable = True
except TypeError:
iterable = False
if len(axes) == 1 and iterable:
return transpose(self, axes[0])
else:
return transpose(self, axes)
shape = property(lambda self: shape(self))
size = property(lambda self: prod(self.shape))
# We can't implement __len__ to provide a better error message.
......@@ -1347,18 +1371,23 @@ class _tensor_py_operators:
# CONVENIENT ACCESS TO TYPE PROPERTIES
ndim = property(lambda self: self.type.ndim)
"""The rank of this tensor."""
broadcastable = property(lambda self: self.type.broadcastable)
"""The broadcastable signature of this tensor.
See :doc:`broadcasting` for details.
"""
dtype = property(lambda self: self.type.dtype)
""" The dtype of this tensor. """
#extra pseudo-operator symbols
def __dot__(left, right): return dot(left, right)
def __rdot__(right, left): return dot(left, right)
def __dot__(left, right):
return dot(left, right)
def __rdot__(right, left):
return dot(left, right)
def sum(self, axis=None):
return elemwise.Sum(axis)(self)
......@@ -1390,7 +1419,6 @@ class _tensor_py_operators:
#TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000
def get_constant_value(self):
return get_constant_value(self)
......@@ -2934,10 +2962,17 @@ def get_canonical_form_slice(theslice, length):
return value, 1
def transpose(x, **kwargs):
dims = range(x.ndim-1, -1, -1)
def transpose(x, axes=None):
"""
Reorder the dimensions of x. (Default: reverse them)
This is a macro around dimshuffle that matches the numpy.transpose
function.
return DimShuffle(x.broadcastable, dims, inplace=False)(x)
"""
if axes is None:
axes = range(x.ndim-1, -1, -1)
return DimShuffle(x.broadcastable, axes, inplace=False)(x)
class AdvancedIndexingError(TypeError):
......
......@@ -5099,6 +5099,7 @@ def test_mod():
):
assert fn(a,b) == a%b, (a,)
def test_mod_compile():
"""
This test generate an Elemwise of Composite as:
......@@ -5122,6 +5123,7 @@ def test_mod_compile():
f = theano.function([x,y],out)
def test_unalign():
if config.floatX == 'float64':
dtype="b1,f8"
......@@ -5153,6 +5155,7 @@ def test_unalign():
if not should_raise:
raise Exception("Theano raised an exception when none was expected")
def test_dimshuffle_duplicate():
x = tensor.vector()
......@@ -5167,7 +5170,6 @@ def test_dimshuffle_duplicate():
assert success
class T_get_constant_value(unittest.TestCase):
def test_get_constant_value(self):
a = tensor.stack(1,2,3)
assert get_constant_value(a[0])==1
......@@ -5202,6 +5204,7 @@ class T_get_constant_value(unittest.TestCase):
for j in range(c.value.shape[1]):
assert get_constant_value(c[i,j]) == c.value[i,j]
class T_as_tensor_variable(unittest.TestCase):
"""
We test that ticket #649 stay fixed.
......@@ -5231,7 +5234,6 @@ class test_complex_mod(unittest.TestCase):
class test_size(unittest.TestCase):
"""
Ensure the `size` attribute of tensors behaves as in numpy.
"""
......@@ -5259,7 +5261,6 @@ class test_size(unittest.TestCase):
class test_numpy_assumptions(unittest.TestCase):
"""
Verify that some assumptions Theano makes on Numpy's behavior still hold.
"""
......@@ -5290,6 +5291,49 @@ class test_numpy_assumptions(unittest.TestCase):
assert (dtype1 == dtype2) == (str(dtype1) == str(dtype2))
def test_transpose():
x1 = tensor.dvector()
x2 = tensor.dmatrix()
x3 = tensor.dtensor3()
x1v = numpy.arange(24)
x2v = numpy.arange(24).reshape(2, 12)
x3v = numpy.arange(24).reshape(2, 3, 4)
f = theano.function([x1, x2, x3], [
tensor.transpose(x1),
tensor.transpose(x2),
tensor.transpose(x3),
x1.transpose(),
x2.transpose(),
x3.transpose(),
x2.transpose(0, 1),
x3.transpose((0, 2, 1)),
tensor.transpose(x2, [0, 1]),
tensor.transpose(x3, [0, 2, 1]),
])
t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v)
assert t1.shape == numpy.transpose(x1v).shape
assert t2.shape == numpy.transpose(x2v).shape
assert t3.shape == numpy.transpose(x3v).shape
assert numpy.all(t1 == numpy.transpose(x1v))
assert numpy.all(t2 == numpy.transpose(x2v))
assert numpy.all(t3 == numpy.transpose(x3v))
assert numpy.all(t1b == x1v.transpose())
assert numpy.all(t2b == x2v.transpose())
assert numpy.all(t3b == x3v.transpose())
assert t2c.shape == (2, 12)
assert t3c.shape == (2, 4, 3)
assert numpy.all(t2c == x2v.transpose([0, 1]))
assert numpy.all(t3c == x3v.transpose([0, 2, 1]))
assert t2d.shape == (2, 12)
assert t3d.shape == (2, 4, 3)
assert numpy.all(t2d == numpy.transpose(x2v, [0, 1]))
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
if __name__ == '__main__':
if 0:
unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论