提交 e0d81475，作者：Olivier Breuleux

merge

...@@ -52,6 +52,7 @@ def __src_version__(): ...@@ -52,6 +52,7 @@ def __src_version__():
# #
if not hasattr(__src_version__, 'rval'):
#print 'name:', __name__ #print 'name:', __name__
location = _imp.find_module(__name__)[1] location = _imp.find_module(__name__)[1]
#print 'location:', location #print 'location:', location
...@@ -80,5 +81,7 @@ def __src_version__(): ...@@ -80,5 +81,7 @@ def __src_version__():
if len(tokens) == 2: if len(tokens) == 2:
assert tokens[1] == 'tip\n' assert tokens[1] == 'tip\n'
return tokens[0] __src_version__.rval = tokens[0]
return __src_version__.rval
...@@ -1282,22 +1282,22 @@ class t_dot(unittest.TestCase): ...@@ -1282,22 +1282,22 @@ class t_dot(unittest.TestCase):
self.failUnless(tz.shape == nz.shape) self.failUnless(tz.shape == nz.shape)
self.failUnless(_approx_eq(nz, tz)) self.failUnless(_approx_eq(nz, tz))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2) #def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5)) #def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7)) #def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7)) #def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 ) #def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5)) def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7)) def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7)) #def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0) #def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6)) def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7)) def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7)) #def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0) #def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6)) #def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7)) #def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7)) #def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def not_aligned(self, x, y): def not_aligned(self, x, y):
z = dot(x,y) z = dot(x,y)
...@@ -1310,16 +1310,22 @@ class t_dot(unittest.TestCase): ...@@ -1310,16 +1310,22 @@ class t_dot(unittest.TestCase):
def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6)) def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6))
def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4)) def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4))
def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7)) #def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6)) def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7)) def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7))
def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8)) #def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8))
def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6)) #def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6))
def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7)) #def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8)) #def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
def test_grad(self): def test_grad(self):
#verify_grad(self, dot, [self.rand(2,3,4), self.rand(4)])
verify_grad(self, dot, [self.rand(2,3), self.rand(3,2)]) verify_grad(self, dot, [self.rand(2,3), self.rand(3,2)])
verify_grad(self, dot, [self.rand(2), self.rand(2,3)])
verify_grad(self, dot, [self.rand(3,2), self.rand(2)])
verify_grad(self, dot, [self.rand(2), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2,5)])
class t_gemm(unittest.TestCase): class t_gemm(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -1703,7 +1709,8 @@ if __name__ == '__main__': ...@@ -1703,7 +1709,8 @@ if __name__ == '__main__':
if 1: if 1:
unittest.main() unittest.main()
else: else:
testcase = t_dot
suite = unittest.TestLoader() suite = unittest.TestLoader()
#suite = suite.loadTestsFromTestCase(T_subtensor) suite = suite.loadTestsFromTestCase(testcase)
suite = suite.loadTestsFromTestCase(T_Stack)
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -69,7 +69,7 @@ def numpy_wrapper(f): ...@@ -69,7 +69,7 @@ def numpy_wrapper(f):
for thunk in thunks: for thunk in thunks:
for output in thunk.outputs: for output in thunk.outputs:
if hasattr(output, 'dtype'): if hasattr(output, 'dtype'):
if f(output)): if f(output):
raise Exception('uh oh', (thunk, output)) raise Exception('uh oh', (thunk, output))
return wrapper return wrapper
......
...@@ -912,13 +912,18 @@ def horizontal_stack(x, y): ...@@ -912,13 +912,18 @@ def horizontal_stack(x, y):
######################### #########################
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
"""
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(as_tensor, inputs) inputs = map(as_tensor, inputs)
numpy_semantics = 0
if numpy_semantics:
#numpy defines dot for tensor pairs with any rank
if len(inputs) != 2: if len(inputs) != 2:
raise TypeError("Wrong number of inputs for %s (got %i, expected 2)" % self) raise TypeError("Wrong number of inputs for %s (got %i, expected 2)" % self)
i_broadcastables = [input.type.broadcastable for input in inputs] i_broadcastables = [input.type.broadcastable for input in inputs]
i_dtypes = [input.type.dtype for input in inputs]
bx, by = i_broadcastables bx, by = i_broadcastables
if len(bx) == 0: # x is a scalar if len(bx) == 0: # x is a scalar
bz = by bz = by
...@@ -929,20 +934,68 @@ class Dot(Op): ...@@ -929,20 +934,68 @@ class Dot(Op):
bz = bx[:-1] bz = bx[:-1]
else: #y is a scalar else: #y is a scalar
bz = bx bz = bx
o_broadcastables = [bz] else:
o_dtypes = [scal.upcast(*i_dtypes)] x, y = inputs
nx = x.type.ndim
outputs = [tensor(t, b) for b, t in zip(o_broadcastables, o_dtypes)] ny = y.type.ndim
if nx not in (1,2): raise TypeError('not matrix or vector', x)
if ny not in (1,2): raise TypeError('not matrix or vector', y)
if nx == 2 and ny == 2:
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
elif nx == 1 and ny == 2:
bz = [y.type.broadcastable[1]]
elif nx == 2 and ny == 1:
bz = [x.type.broadcastable[0]]
else:
bz = []
i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(scal.upcast(*i_dtypes), bz)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, (x, y), (z, )): def perform(self, node, (x, y), (z, )):
z[0] = numpy.dot(x, y) z[0] = numpy.dot(x, y)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0:
return gz * y, gz * x
if x.type.ndim == 1 and y.type.ndim > 1:
return dot(gz, y.T), outer(x.T, gz)
if x.type.ndim > 1 and y.type.ndim == 1:
return outer(gz, y.T), dot(x.T, gz)
return dot(gz, y.T), dot(x.T, gz) return dot(gz, y.T), dot(x.T, gz)
def __str__(self): def __str__(self):
return "dot" return "dot"
dot = Dot() dot = Dot()
class Outer(Op):
    """Compute the vector-vector outer product.

    Both inputs must be rank-1 tensors; the single output is a rank-2 tensor
    with ``z[i, j] = x[i] * y[j]`` (computed by ``numpy.outer``).
    """

    def make_node(self, *inputs):
        """Build the Apply node for ``outer(x, y)``.

        The inputs are converted with ``as_tensor``.  The output's
        broadcastable pattern is taken from the single dimension of each
        input, and its dtype is the upcast of the two input dtypes.

        Raises:
            TypeError: if either input is not a vector (``ndim != 1``).
        """
        # Build a real list rather than calling ``map``: the sequence is
        # unpacked here and then handed to ``Apply``, so it must be
        # re-iterable (``map`` is a one-shot iterator on Python 3).
        inputs = [as_tensor(inp) for inp in inputs]
        x, y = inputs
        nx = x.type.ndim
        ny = y.type.ndim
        if nx != 1:
            raise TypeError('not vector', x)
        if ny != 1:
            raise TypeError('not vector', y)
        bz = [x.type.broadcastable[0], y.type.broadcastable[0]]
        i_dtypes = [inp.type.dtype for inp in inputs]
        outputs = [tensor(scal.upcast(*i_dtypes), bz)]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, output_storage):
        """Store ``numpy.outer(x, y)`` into the first slot of the output cell.

        ``inputs`` is the pair ``(x, y)`` and ``output_storage`` is the
        one-element sequence ``(z,)``; the result is written to ``z[0]``.
        """
        # Unpack in the body instead of the signature: tuple parameters were
        # removed from the language in Python 3 (PEP 3113).
        x, y = inputs
        z, = output_storage
        z[0] = numpy.outer(x, y)

    def grad(self, inputs, output_grads):
        """Return the gradients with respect to ``x`` and ``y``.

        With ``z[i, j] = x[i] * y[j]`` and ``gz`` the gradient on ``z``:
        ``gx[i] = sum_j gz[i, j] * y[j]`` and ``gy[j] = sum_i x[i] * gz[i, j]``.
        """
        x, y = inputs
        gz, = output_grads
        # x and y are vectors, so no transposing is necessary.
        return dot(gz, y), dot(x, gz)

    def __str__(self):
        return "outer"


outer = Outer()
class Gemm(Op): class Gemm(Op):
E_rank = 'gemm only works for rank 2' E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument' E_scalar = 'gemm requires scalar argument'
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论