fixed transpose as per bug #81

上级 f5801429
......@@ -159,41 +159,42 @@ class T_transpose(unittest.TestCase):
def test0(self):
n = astensor(numpy.ones(()))
t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose)
self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t])
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data == 56.0)
self.failUnless(n.data == 1.0)
def test1(self):
n = astensor(numpy.ones(5))
t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose)
self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t])
tval = f(n.data)
self.failUnless(tval.shape == n.data.shape)
#test aliasing
tval += 55.0
self.failUnless(n.data[0] == 56.0)
self.failUnless(n.data[0] == 1.0)
def test2(self):
n = astensor(numpy.ones((5,3)))
t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose)
self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t])
tval = f(n.data)
self.failUnless(tval.shape == (3,5))
#test aliasing
tval += 55.0
self.failUnless(n.data[0,0] == 56.0)
self.failUnless(n.data[0,0] == 1.0)
def test3(self):
"""Test transpose of tensor, inplace version"""
n = astensor(numpy.ones((5,3,2)))
t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose)
t = transpose_inplace(n)
self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t])
tval = f(n.data)
self.failUnless(tval.shape == (2,3,5))
......
......@@ -59,9 +59,7 @@ class Tensor(BaseTensor):
def __rpow__(self,other): return pow(other,self)
#TRANSPOSE
def __get_T(self):
return tensor_copy(transpose(self))
T = property(__get_T)
T = property(lambda self: transpose(self))
#SLICING
def __getitem__(self, item): return subtensor(self, item)
......@@ -357,7 +355,7 @@ tensor_copy = gof.op.constructor(TensorCopy)
# View Operations
##########################
class Transpose(_Op, Viewer):
class TransposeInplace(_Op, Viewer):
def view_map(self):
return {self.out: [self.inputs[0]]}
def propagate_broadcastable(self, x):
......@@ -367,7 +365,7 @@ class Transpose(_Op, Viewer):
def impl(self, x):
return x.T #numpy's transpose
def grad(self, x, gz):
return transpose_copy(gz)
return transpose(gz)
def c_impl(self, x, z):
return """
......@@ -377,7 +375,9 @@ class Transpose(_Op, Viewer):
}
%(z)s = transposed;
"""
transpose = gof.op.constructor(Transpose)
transpose_inplace = gof.op.constructor(TransposeInplace)
def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs)
class Subtensor(Op, Viewer):
nin = 2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论