提交 6f5a1844 authored 作者: James Bergstra's avatar James Bergstra

documentation and revisions to blas optimization pipeline

上级 d36528cb
差异被折叠。
...@@ -83,7 +83,7 @@ class t_gemm(TestCase): ...@@ -83,7 +83,7 @@ class t_gemm(TestCase):
Gemm.debug = True Gemm.debug = True
try: try:
g = gemm_inplace([1.], 1., [1.], [1.], 1.) g = gemm_inplace([1.], 1., [1.], [1.], 1.)
except ValueError, e: except TypeError, e:
if e[0] is Gemm.E_rank: if e[0] is Gemm.E_rank:
return return
self.fail() self.fail()
...@@ -91,7 +91,7 @@ class t_gemm(TestCase): ...@@ -91,7 +91,7 @@ class t_gemm(TestCase):
def test0(self): def test0(self):
try: try:
self.cmp(1., 0., 1.0, 1.0, 1.0) self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e: except TypeError, e:
if e[0] is Gemm.E_rank: if e[0] is Gemm.E_rank:
return return
self.fail() self.fail()
...@@ -99,7 +99,7 @@ class t_gemm(TestCase): ...@@ -99,7 +99,7 @@ class t_gemm(TestCase):
def test2(self): def test2(self):
try: try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0) self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e: except TypeError, e:
self.assertTrue(e[0] == Gemm.E_rank) self.assertTrue(e[0] == Gemm.E_rank)
return return
self.fail() self.fail()
...@@ -124,14 +124,14 @@ class t_gemm(TestCase): ...@@ -124,14 +124,14 @@ class t_gemm(TestCase):
self.rand(3,5), self.rand(5,4), -1.0) self.rand(3,5), self.rand(5,4), -1.0)
def test_factorised_scalar(self): def test_factorised_scalar(self):
a=T.matrix() a=T.dmatrix()
b=T.matrix() b=T.dmatrix()
c=T.matrix() c=T.dmatrix()
s=theano.shared(numpy.zeros((5,5))) s=theano.shared(numpy.zeros((5,5)))
lr1=T.constant(0.01) lr1=T.constant(0.01).astype('float64')
lr2=T.constant(2) lr2=T.constant(2).astype('float64')
l2_reg=T.constant(0.0001) l2_reg=T.constant(0.0001).astype('float64')
#test constant merge with gemm #test constant merge with gemm
f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s},mode=mode_not_fast_compile).maker.env.toposort() f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s},mode=mode_not_fast_compile).maker.env.toposort()
...@@ -195,9 +195,10 @@ class t_gemm(TestCase): ...@@ -195,9 +195,10 @@ class t_gemm(TestCase):
"""test that dot args can be aliased""" """test that dot args can be aliased"""
Z = shared(self.rand(2,2)) Z = shared(self.rand(2,2))
A = shared(self.rand(2,2)) A = shared(self.rand(2,2))
f = inplace_func([], gemm_inplace(Z, 1.0, A, A, 1.0)) one = T.constant(1.0).astype(Z.dtype)
f = inplace_func([], gemm_inplace(Z, one, A, A, one))
f() f()
f = inplace_func([], gemm_inplace(Z, 1.0, A, A.T, 1.0)) f = inplace_func([], gemm_inplace(Z, one, A, A.T, one))
f() f()
def test_transposes(self): def test_transposes(self):
...@@ -451,7 +452,8 @@ def test_gemm_opt_double_gemm(): ...@@ -451,7 +452,8 @@ def test_gemm_opt_double_gemm():
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()] ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c] i = [X,Y,Z,a,b, R, S, c]
o = [a * T.dot(X,Y) + gemm_inplace(Z, b, S.T, R.T, 1.0)] o = [(a * T.dot(X,Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))]
try: try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, f = inplace_func([Param(ii, mutable=True) for ii in i],o,
mode='FAST_RUN') mode='FAST_RUN')
...@@ -765,8 +767,9 @@ def test_dot_vm(): ...@@ -765,8 +767,9 @@ def test_dot_vm():
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value())) assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value()))
# Assert that the dot was optimized somehow
assert sum([isinstance(node.op, T.Dot) for node in assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 1 f.maker.env.toposort() ]) == 0
def test_dot_mv(): def test_dot_mv():
''' Test matrix dot vector ''' ''' Test matrix dot vector '''
...@@ -779,8 +782,9 @@ def test_dot_mv(): ...@@ -779,8 +782,9 @@ def test_dot_mv():
# Assert they produce the same output # Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(m.get_value(), v.get_value())) assert numpy.allclose(f(), numpy.dot(m.get_value(), v.get_value()))
# Assert that the dot was optimized somehow
assert sum([isinstance(node.op, T.Dot) for node in assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 1 f.maker.env.toposort() ]) == 0
class TestGemv(TestCase): class TestGemv(TestCase):
def test_gemv1(self): def test_gemv1(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论