fixed transpose as per bug #81

上级 f5801429
...@@ -159,41 +159,42 @@ class T_transpose(unittest.TestCase): ...@@ -159,41 +159,42 @@ class T_transpose(unittest.TestCase):
def test0(self): def test0(self):
n = astensor(numpy.ones(())) n = astensor(numpy.ones(()))
t = transpose(n) t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose) self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) f = Function([n], [t])
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == n.data.shape) self.failUnless(tval.shape == n.data.shape)
#test aliasing #test aliasing
tval += 55.0 tval += 55.0
self.failUnless(n.data == 56.0) self.failUnless(n.data == 1.0)
def test1(self): def test1(self):
n = astensor(numpy.ones(5)) n = astensor(numpy.ones(5))
t = transpose(n) t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose) self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) f = Function([n], [t])
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == n.data.shape) self.failUnless(tval.shape == n.data.shape)
#test aliasing #test aliasing
tval += 55.0 tval += 55.0
self.failUnless(n.data[0] == 56.0) self.failUnless(n.data[0] == 1.0)
def test2(self): def test2(self):
n = astensor(numpy.ones((5,3))) n = astensor(numpy.ones((5,3)))
t = transpose(n) t = transpose(n)
self.failUnless(t.owner.__class__ is Transpose) self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) f = Function([n], [t])
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == (3,5)) self.failUnless(tval.shape == (3,5))
#test aliasing #test aliasing
tval += 55.0 tval += 55.0
self.failUnless(n.data[0,0] == 56.0) self.failUnless(n.data[0,0] == 1.0)
def test3(self): def test3(self):
"""Test transpose of tensor, inplace version"""
n = astensor(numpy.ones((5,3,2))) n = astensor(numpy.ones((5,3,2)))
t = transpose(n) t = transpose_inplace(n)
self.failUnless(t.owner.__class__ is Transpose) self.failUnless(t.owner.__class__ is TransposeInplace)
f = Function([n], [t]) f = Function([n], [t])
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == (2,3,5)) self.failUnless(tval.shape == (2,3,5))
......
...@@ -59,9 +59,7 @@ class Tensor(BaseTensor): ...@@ -59,9 +59,7 @@ class Tensor(BaseTensor):
def __rpow__(self,other): return pow(other,self) def __rpow__(self,other): return pow(other,self)
#TRANSPOSE #TRANSPOSE
def __get_T(self): T = property(lambda self: transpose(self))
return tensor_copy(transpose(self))
T = property(__get_T)
#SLICING #SLICING
def __getitem__(self, item): return subtensor(self, item) def __getitem__(self, item): return subtensor(self, item)
...@@ -357,7 +355,7 @@ tensor_copy = gof.op.constructor(TensorCopy) ...@@ -357,7 +355,7 @@ tensor_copy = gof.op.constructor(TensorCopy)
# View Operations # View Operations
########################## ##########################
class Transpose(_Op, Viewer): class TransposeInplace(_Op, Viewer):
def view_map(self): def view_map(self):
return {self.out: [self.inputs[0]]} return {self.out: [self.inputs[0]]}
def propagate_broadcastable(self, x): def propagate_broadcastable(self, x):
...@@ -367,7 +365,7 @@ class Transpose(_Op, Viewer): ...@@ -367,7 +365,7 @@ class Transpose(_Op, Viewer):
def impl(self, x): def impl(self, x):
return x.T #numpy's transpose return x.T #numpy's transpose
def grad(self, x, gz): def grad(self, x, gz):
return transpose_copy(gz) return transpose(gz)
def c_impl(self, x, z): def c_impl(self, x, z):
return """ return """
...@@ -377,7 +375,9 @@ class Transpose(_Op, Viewer): ...@@ -377,7 +375,9 @@ class Transpose(_Op, Viewer):
} }
%(z)s = transposed; %(z)s = transposed;
""" """
transpose = gof.op.constructor(Transpose) transpose_inplace = gof.op.constructor(TransposeInplace)
def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs)
class Subtensor(Op, Viewer): class Subtensor(Op, Viewer):
nin = 2 nin = 2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论