提交 efa2be5b 作者: James Bergstra

fixed #156

上级 f2903b1e
...@@ -1282,22 +1282,22 @@ class t_dot(unittest.TestCase): ...@@ -1282,22 +1282,22 @@ class t_dot(unittest.TestCase):
self.failUnless(tz.shape == nz.shape) self.failUnless(tz.shape == nz.shape)
self.failUnless(_approx_eq(nz, tz)) self.failUnless(_approx_eq(nz, tz))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2) #def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5)) #def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7)) #def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7)) #def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 ) #def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5)) def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7)) def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7)) #def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0) #def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6)) def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7)) def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7)) #def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0) #def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6)) #def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7)) #def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7)) #def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def not_aligned(self, x, y): def not_aligned(self, x, y):
z = dot(x,y) z = dot(x,y)
...@@ -1310,16 +1310,22 @@ class t_dot(unittest.TestCase): ...@@ -1310,16 +1310,22 @@ class t_dot(unittest.TestCase):
def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6)) def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6))
def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4)) def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4))
def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7)) #def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6)) def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7)) def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7))
def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8)) #def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8))
def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6)) #def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6))
def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7)) #def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8)) #def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
def test_grad(self): def test_grad(self):
#verify_grad(self, dot, [self.rand(2,3,4), self.rand(4)])
verify_grad(self, dot, [self.rand(2,3), self.rand(3,2)]) verify_grad(self, dot, [self.rand(2,3), self.rand(3,2)])
verify_grad(self, dot, [self.rand(2), self.rand(2,3)])
verify_grad(self, dot, [self.rand(3,2), self.rand(2)])
verify_grad(self, dot, [self.rand(2), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2,5)])
class t_gemm(unittest.TestCase): class t_gemm(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -1703,7 +1709,8 @@ if __name__ == '__main__': ...@@ -1703,7 +1709,8 @@ if __name__ == '__main__':
if 1: if 1:
unittest.main() unittest.main()
else: else:
testcase = t_dot
suite = unittest.TestLoader() suite = unittest.TestLoader()
#suite = suite.loadTestsFromTestCase(T_subtensor) suite = suite.loadTestsFromTestCase(testcase)
suite = suite.loadTestsFromTestCase(T_Stack)
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -912,37 +912,90 @@ def horizontal_stack(x, y): ...@@ -912,37 +912,90 @@ def horizontal_stack(x, y):
######################### #########################
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
"""
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(as_tensor, inputs) inputs = map(as_tensor, inputs)
if len(inputs) != 2:
raise TypeError("Wrong number of inputs for %s (got %i, expected 2)" % self)
i_broadcastables = [input.type.broadcastable for input in inputs]
i_dtypes = [input.type.dtype for input in inputs]
bx, by = i_broadcastables numpy_semantics = 0
if len(bx) == 0: # x is a scalar if numpy_semantics:
bz = by #numpy defines dot for tensor pairs with any rank
if len(inputs) != 2:
raise TypeError("Wrong number of inputs for %s (got %i, expected 2)" % self)
i_broadcastables = [input.type.broadcastable for input in inputs]
bx, by = i_broadcastables
if len(bx) == 0: # x is a scalar
bz = by
else:
if len(by) >= 2: #y is a matrix or tensor
bz = bx[:-1] + by[:-2] + by[-1:]
elif len(by)==1: #y is vector
bz = bx[:-1]
else: #y is a scalar
bz = bx
else: else:
if len(by) >= 2: #y is a matrix or tensor x, y = inputs
bz = bx[:-1] + by[:-2] + by[-1:] nx = x.type.ndim
elif len(by)==1: #y is vector ny = y.type.ndim
bz = bx[:-1]
else: #y is a scalar if nx not in (1,2): raise TypeError('not matrix or vector', x)
bz = bx if ny not in (1,2): raise TypeError('not matrix or vector', y)
o_broadcastables = [bz]
o_dtypes = [scal.upcast(*i_dtypes)] if nx == 2 and ny == 2:
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [tensor(t, b) for b, t in zip(o_broadcastables, o_dtypes)] elif nx == 1 and ny == 2:
bz = [y.type.broadcastable[1]]
elif nx == 2 and ny == 1:
bz = [x.type.broadcastable[0]]
else:
bz = []
i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(scal.upcast(*i_dtypes), bz)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, (x, y), (z, )): def perform(self, node, (x, y), (z, )):
z[0] = numpy.dot(x, y) z[0] = numpy.dot(x, y)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0:
return gz * y, gz * x
if x.type.ndim == 1 and y.type.ndim > 1:
return dot(gz, y.T), outer(x.T, gz)
if x.type.ndim > 1 and y.type.ndim == 1:
return outer(gz, y.T), dot(x.T, gz)
return dot(gz, y.T), dot(x.T, gz) return dot(gz, y.T), dot(x.T, gz)
def __str__(self): def __str__(self):
return "dot" return "dot"
dot = Dot() dot = Dot()
class Outer(Op):
    """Compute the outer product of two vectors.

    Both inputs must be rank-1 tensors.  The single output is built with a
    two-entry broadcastable pattern and a dtype obtained by upcasting the
    two input dtypes.
    """

    def make_node(self, *inputs):
        """Build the Apply node for ``outer(x, y)``.

        Every input is passed through ``as_tensor`` first.

        Raises TypeError if either input is not of rank 1.
        """
        # Materialize as a list: the inputs are traversed more than once
        # below, and ``map`` is a one-shot iterator on Python 3.
        inputs = [as_tensor(i) for i in inputs]
        x, y = inputs
        nx = x.type.ndim
        ny = y.type.ndim
        if nx != 1:
            raise TypeError('not vector', x)
        if ny != 1:
            raise TypeError('not vector', y)

        # First output dimension follows x, second follows y.
        bz = [x.type.broadcastable[0], y.type.broadcastable[0]]
        i_dtypes = [i.type.dtype for i in inputs]
        outputs = [tensor(scal.upcast(*i_dtypes), bz)]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, output_storage):
        """Store ``numpy.outer(x, y)`` in the first slot of the output cell."""
        # Unpack in the body rather than in the signature: tuple parameters
        # are a syntax error on Python 3, while this form works on both.
        x, y = inputs
        z, = output_storage
        z[0] = numpy.outer(x, y)

    def grad(self, inputs, output_grads):
        """Return the pair ``(dot(gz, y), dot(x, gz))`` for output gradient gz."""
        x, y = inputs
        gz, = output_grads
        return dot(gz, y), dot(x, gz)  # no transposing necessary

    def __str__(self):
        return "outer"


outer = Outer()
class Gemm(Op): class Gemm(Op):
E_rank = 'gemm only works for rank 2' E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument' E_scalar = 'gemm requires scalar argument'
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论