提交 97c61d06 authored 作者: Frederic Bastien's avatar Frederic Bastien

Test that gemv are inserted into the graph. Disable the gemv test on complex as…

Test that gemv are inserted into the graph. Disable the gemv test on complex as they don't get inserted.
上级 0226809b
...@@ -809,6 +809,9 @@ class BaseGemv(object): ...@@ -809,6 +809,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_val = oy_func() oy_val = oy_func()
assert_array_almost_equal(desired_oy, oy_val) assert_array_almost_equal(desired_oy, oy_val)
...@@ -826,6 +829,20 @@ class BaseGemv(object): ...@@ -826,6 +829,20 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
# The only op in the graph is a dot.
# In the gemm case, we create a dot22 for that case
# There is no dot21.
# Creating one is not usefull as this is not faster(in fact it would be slower!
# as more code would be in python, numpy.dot will call gemv itself)
"""
>>> t0=time.time();x=scipy.linalg.blas.fblas.dgemv(1,a.T,b,1,z.T);t1=time.time();print t1-t0
0.00192999839783
>>> t0=time.time();x=numpy.dot(a,b);t1=time.time();print t1-t0
0.00158381462097
"""
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==0
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
...@@ -841,6 +858,9 @@ class BaseGemv(object): ...@@ -841,6 +858,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
...@@ -855,6 +875,9 @@ class BaseGemv(object): ...@@ -855,6 +875,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
...@@ -869,6 +892,9 @@ class BaseGemv(object): ...@@ -869,6 +892,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
...@@ -883,6 +909,9 @@ class BaseGemv(object): ...@@ -883,6 +909,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
...@@ -897,25 +926,23 @@ class BaseGemv(object): ...@@ -897,25 +926,23 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt) oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func() oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v) assert_array_almost_equal(desired_oy, oy_v)
try: class TestSgemv(TestCase, BaseGemv):
class TestSgmev(TestCase, BaseGemv): dtype = float32
dtype = float32
except AttributeError:
class TestSgmev: pass
class TestDgemv(TestCase, BaseGemv): class TestDgemv(TestCase, BaseGemv):
dtype = float64 dtype = float64
try: #The optimization to put Gemv don't work for complex type for now.
class TestCgemv(TestCase, BaseGemv): #class TestCgemv(TestCase, BaseGemv):
dtype = complex64 # dtype = complex64
except AttributeError:
class TestCgemv: pass
class TestZgemv(TestCase, BaseGemv): #class TestZgemv(TestCase, BaseGemv):
dtype = complex128 # dtype = complex128
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论