提交 efa2be5b 作者: James Bergstra

fixed #156

上级 f2903b1e
......@@ -1282,22 +1282,22 @@ class t_dot(unittest.TestCase):
self.failUnless(tz.shape == nz.shape)
self.failUnless(_approx_eq(nz, tz))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
#def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
#def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
#def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
#def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
#def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
#def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
#def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
#def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
#def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
#def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
#def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
#def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def not_aligned(self, x, y):
z = dot(x,y)
......@@ -1310,16 +1310,22 @@ class t_dot(unittest.TestCase):
def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6))
def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4))
def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7))
#def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6))
def test_align_2_2(self):
    # 2-D x 2-D operands whose inner dimensions differ (4 vs 6).
    # This method used to be named test_align_2_1, the same as the 2-D x 1-D
    # case defined just before it; the later definition rebinds the name in
    # the class body, so the earlier test was silently dropped.  The name now
    # follows the test_align_<rank of x>_<rank of y> pattern of its siblings.
    self.not_aligned(self.rand(5,4), self.rand(6,7))
def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8))
def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6))
def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
#def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8))
#def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6))
#def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
#def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
def test_grad(self):
    # Run verify_grad on dot for each operand-shape pairing listed below,
    # in order: matrix.matrix, vector.matrix, matrix.vector, vector.vector.
    #
    # Cases left disabled (kept for reference):
    #verify_grad(self, dot, [self.rand(2,3,4), self.rand(4)])
    #verify_grad(self, dot, [self.rand(), self.rand(2)])
    #verify_grad(self, dot, [self.rand(), self.rand(2,5)])
    operand_shapes = (
        ((2, 3), (3, 2)),
        ((2,), (2, 3)),
        ((3, 2), (2,)),
        ((2,), (2,)),
    )
    for left_shape, right_shape in operand_shapes:
        # Draw the left operand first, then the right, so the sequence of
        # self.rand calls is the same as when the cases are written out.
        left = self.rand(*left_shape)
        right = self.rand(*right_shape)
        verify_grad(self, dot, [left, right])
class t_gemm(unittest.TestCase):
def setUp(self):
......@@ -1703,7 +1709,8 @@ if __name__ == '__main__':
if 1:
unittest.main()
else:
testcase = t_dot
suite = unittest.TestLoader()
#suite = suite.loadTestsFromTestCase(T_subtensor)
suite = suite.loadTestsFromTestCase(T_Stack)
suite = suite.loadTestsFromTestCase(testcase)
unittest.TextTestRunner(verbosity=2).run(suite)
......@@ -912,37 +912,90 @@ def horizontal_stack(x, y):
#########################
class Dot(Op):
    """Compute matrix-matrix, matrix-vector products and vector inner-products.

    Both operands must be rank 1 (vector) or rank 2 (matrix); any other rank
    is rejected by make_node with a TypeError.
    """

    def make_node(self, *inputs):
        """Build the Apply node for dot(x, y).

        inputs -- exactly two values; each is passed through as_tensor.

        Returns Apply(self, inputs, [z]) where z's dtype is
        scal.upcast of the two input dtypes and z's broadcastable pattern is:
            matrix . matrix -> [x.broadcastable[0], y.broadcastable[1]]
            vector . matrix -> [y.broadcastable[1]]
            matrix . vector -> [x.broadcastable[0]]
            vector . vector -> []   (rank-0 result)

        Raises TypeError if the number of inputs is not 2, or if either
        operand is not rank 1 or 2.
        """
        # Check arity up front: unpacking a wrong-length sequence would
        # otherwise fail with an uninformative ValueError.  Both format
        # placeholders are supplied here (the message used to be formatted
        # with `self` alone, which fails inside the % operator itself).
        if len(inputs) != 2:
            raise TypeError("Wrong number of inputs for %s (got %i, expected 2)"
                            % (self, len(inputs)))
        inputs = [as_tensor(i) for i in inputs]
        x, y = inputs

        # Switch kept from the original code: the numpy-style rule below is
        # retained for reference but is disabled.
        numpy_semantics = 0
        if numpy_semantics:
            #numpy defines dot for tensor pairs with any rank
            bx = x.type.broadcastable
            by = y.type.broadcastable
            if len(bx) == 0: # x is a scalar
                bz = by
            elif len(by) >= 2: #y is a matrix or tensor
                bz = bx[:-1] + by[:-2] + by[-1:]
            elif len(by) == 1: #y is vector
                bz = bx[:-1]
            else: #y is a scalar
                bz = bx
        else:
            nx = x.type.ndim
            ny = y.type.ndim
            if nx not in (1,2): raise TypeError('not matrix or vector', x)
            if ny not in (1,2): raise TypeError('not matrix or vector', y)
            if nx == 2 and ny == 2:
                bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
            elif nx == 1 and ny == 2:
                bz = [y.type.broadcastable[1]]
            elif nx == 2 and ny == 1:
                bz = [x.type.broadcastable[0]]
            else:
                bz = []

        i_dtypes = [i.type.dtype for i in inputs]
        outputs = [tensor(scal.upcast(*i_dtypes), bz)]
        return Apply(self, inputs, outputs)

    def perform(self, node, inp, out):
        """Store numpy.dot(x, y) in the single output cell.

        inp -- the pair (x, y); out -- a one-element sequence holding the
        output storage cell z, whose slot 0 receives the result.
        """
        # Explicit unpacking replaces Python-2-only tuple parameters; the
        # method is still called positionally as perform(node, inputs, outputs).
        x, y = inp
        z, = out
        z[0] = numpy.dot(x, y)

    def grad(self, inp, grads):
        """Return the pair (gradient w.r.t. x, gradient w.r.t. y).

        inp -- the pair (x, y); grads -- a one-element sequence holding gz,
        the gradient with respect to the output.
        """
        x, y = inp
        gz, = grads
        if gz.type.ndim == 0:
            # rank-0 output gradient (make_node yields a rank-0 output for
            # vector . vector): plain scaling of the other operand
            return gz * y, gz * x
        if x.type.ndim == 1 and y.type.ndim > 1:
            # vector . matrix: the gradient for y is an outer product
            return dot(gz, y.T), outer(x.T, gz)
        if x.type.ndim > 1 and y.type.ndim == 1:
            # matrix . vector: the gradient for x is an outer product
            return outer(gz, y.T), dot(x.T, gz)
        # matrix . matrix
        return dot(gz, y.T), dot(x.T, gz)

    def __str__(self):
        return "dot"

dot = Dot()
class Outer(Op):
    """Compute vector-vector outer product.

    Both operands must be rank 1; the output is rank 2.
    """

    def make_node(self, *inputs):
        """Build the Apply node for outer(x, y).

        inputs -- exactly two values; each is passed through as_tensor.

        Returns Apply(self, inputs, [z]) where z's dtype is scal.upcast of
        the two input dtypes and z's broadcastable pattern is
        [x.broadcastable[0], y.broadcastable[0]].

        Raises TypeError if the number of inputs is not 2, or if either
        operand is not rank 1.
        """
        # Check arity before unpacking so a wrong call reports a clear
        # TypeError rather than an unpacking ValueError.
        if len(inputs) != 2:
            raise TypeError("Wrong number of inputs for %s (got %i, expected 2)"
                            % (self, len(inputs)))
        inputs = [as_tensor(i) for i in inputs]
        x, y = inputs
        if x.type.ndim != 1: raise TypeError('not vector', x)
        if y.type.ndim != 1: raise TypeError('not vector', y)
        bz = [x.type.broadcastable[0], y.type.broadcastable[0]]
        i_dtypes = [i.type.dtype for i in inputs]
        outputs = [tensor(scal.upcast(*i_dtypes), bz)]
        return Apply(self, inputs, outputs)

    def perform(self, node, inp, out):
        """Store numpy.outer(x, y) in the single output cell.

        inp -- the pair (x, y); out -- a one-element sequence holding the
        output storage cell z, whose slot 0 receives the result.
        """
        # Explicit unpacking replaces Python-2-only tuple parameters; the
        # method is still called positionally as perform(node, inputs, outputs).
        x, y = inp
        z, = out
        z[0] = numpy.outer(x, y)

    def grad(self, inp, grads):
        """Return the pair (gradient w.r.t. x, gradient w.r.t. y).

        inp -- the pair (x, y); grads -- a one-element sequence holding gz,
        the gradient with respect to the output.
        """
        x, y = inp
        gz, = grads
        return dot(gz, y), dot(x, gz) #no transposing necessary

    def __str__(self):
        return "outer"

outer = Outer()
class Gemm(Op):
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论