提交 97c61d06 authored 作者: Frederic Bastien's avatar Frederic Bastien

Test that gemv are inserted into the graph. Disable the gemv test on complex as…

Test that gemv are inserted into the graph. Disable the gemv test on complex as they don't get inserted.
上级 0226809b
......@@ -809,6 +809,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_val = oy_func()
assert_array_almost_equal(desired_oy, oy_val)
......@@ -826,6 +829,20 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
# The only op in the graph is a dot.
# In the gemm case, we create a dot22 for that case
# There is no dot21.
# Creating one is not usefull as this is not faster(in fact it would be slower!
# as more code would be in python, numpy.dot will call gemv itself)
"""
>>> t0=time.time();x=scipy.linalg.blas.fblas.dgemv(1,a.T,b,1,z.T);t1=time.time();print t1-t0
0.00192999839783
>>> t0=time.time();x=numpy.dot(a,b);t1=time.time();print t1-t0
0.00158381462097
"""
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==0
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -841,6 +858,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -855,6 +875,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -869,6 +892,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -883,6 +909,9 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -897,25 +926,23 @@ class BaseGemv(object):
oy_func = theano.function([], oy, mode = mode_blas_opt)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
try:
class TestSgmev(TestCase, BaseGemv):
dtype = float32
except AttributeError:
class TestSgmev: pass
class TestSgemv(TestCase, BaseGemv):
dtype = float32
class TestDgemv(TestCase, BaseGemv):
dtype = float64
try:
class TestCgemv(TestCase, BaseGemv):
dtype = complex64
except AttributeError:
class TestCgemv: pass
#The optimization to put Gemv don't work for complex type for now.
#class TestCgemv(TestCase, BaseGemv):
# dtype = complex64
class TestZgemv(TestCase, BaseGemv):
dtype = complex128
#class TestZgemv(TestCase, BaseGemv):
# dtype = complex128
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论