提交 6f5a1844 authored 作者: James Bergstra's avatar James Bergstra

documentation and revisions to blas optimization pipeline

上级 d36528cb
差异被折叠。
......@@ -83,7 +83,7 @@ class t_gemm(TestCase):
Gemm.debug = True
try:
g = gemm_inplace([1.], 1., [1.], [1.], 1.)
except ValueError, e:
except TypeError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
......@@ -91,7 +91,7 @@ class t_gemm(TestCase):
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
except TypeError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
......@@ -99,7 +99,7 @@ class t_gemm(TestCase):
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
except TypeError, e:
self.assertTrue(e[0] == Gemm.E_rank)
return
self.fail()
......@@ -124,14 +124,14 @@ class t_gemm(TestCase):
self.rand(3,5), self.rand(5,4), -1.0)
def test_factorised_scalar(self):
a=T.matrix()
b=T.matrix()
c=T.matrix()
a=T.dmatrix()
b=T.dmatrix()
c=T.dmatrix()
s=theano.shared(numpy.zeros((5,5)))
lr1=T.constant(0.01)
lr2=T.constant(2)
l2_reg=T.constant(0.0001)
lr1=T.constant(0.01).astype('float64')
lr2=T.constant(2).astype('float64')
l2_reg=T.constant(0.0001).astype('float64')
#test constant merge with gemm
f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s},mode=mode_not_fast_compile).maker.env.toposort()
......@@ -195,9 +195,10 @@ class t_gemm(TestCase):
"""test that dot args can be aliased"""
Z = shared(self.rand(2,2))
A = shared(self.rand(2,2))
f = inplace_func([], gemm_inplace(Z, 1.0, A, A, 1.0))
one = T.constant(1.0).astype(Z.dtype)
f = inplace_func([], gemm_inplace(Z, one, A, A, one))
f()
f = inplace_func([], gemm_inplace(Z, 1.0, A, A.T, 1.0))
f = inplace_func([], gemm_inplace(Z, one, A, A.T, one))
f()
def test_transposes(self):
......@@ -451,7 +452,8 @@ def test_gemm_opt_double_gemm():
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c]
o = [a * T.dot(X,Y) + gemm_inplace(Z, b, S.T, R.T, 1.0)]
o = [(a * T.dot(X,Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))]
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o,
mode='FAST_RUN')
......@@ -765,8 +767,9 @@ def test_dot_vm():
# Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value()))
# Assert that the dot was optimized somehow
assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 1
f.maker.env.toposort() ]) == 0
def test_dot_mv():
''' Test matrix dot vector '''
......@@ -779,8 +782,9 @@ def test_dot_mv():
# Assert they produce the same output
assert numpy.allclose(f(), numpy.dot(m.get_value(), v.get_value()))
# Assert that the dot was optimized somehow
assert sum([isinstance(node.op, T.Dot) for node in
f.maker.env.toposort() ]) == 1
f.maker.env.toposort() ]) == 0
class TestGemv(TestCase):
def test_gemv1(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论