提交 9045925f authored 作者: James Bergstra's avatar James Bergstra 提交者: Frederic

refactor test_blas for test_blas_c

上级 e18787c3
......@@ -1018,6 +1018,8 @@ def matrixmultiply(a, b):
class BaseGemv(object):
mode = mode_blas_opt # can be overridden with self.mode
def get_data(self,x_stride=1,y_stride=1):
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
mult = array(1, dtype=self.dtype)
......@@ -1036,10 +1038,10 @@ class BaseGemv(object):
oy = alpha * T.dot(a,x) + beta * y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
self.assertFunctionContains1(oy_func, self.gemv)
oy_val = oy_func()
......@@ -1056,22 +1058,9 @@ class BaseGemv(object):
oy = T.dot(a,x)
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
# The only op in the graph is a dot.
# In the gemm case, we create a dot22 for that case
# There is no dot21.
# Creating one is not useful as this is not faster(in fact it would be slower!
# as more code would be in python, numpy.dot will call gemv itself)
# See ticket 594
"""
>>> t0=time.time();x=scipy.linalg.blas.fblas.dgemv(1,a.T,b,1,z.T);t1=time.time();print t1-t0
0.00192999839783
>>> t0=time.time();x=numpy.dot(a,b);t1=time.time();print t1-t0
0.00158381462097
"""
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==0
self.assertFunctionContains1(oy_func, self.gemv_inplace)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -1086,10 +1075,9 @@ class BaseGemv(object):
oy = alpha * T.dot(a.T,x)+beta*y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
self.assertFunctionContains1(oy_func, self.gemv)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -1103,10 +1091,9 @@ class BaseGemv(object):
oy = alpha * T.dot(a,x[::2])+beta*y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
self.assertFunctionContains1(oy_func, self.gemv)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -1120,10 +1107,9 @@ class BaseGemv(object):
oy = alpha * T.dot(a.T,x[::2])+beta*y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
self.assertFunctionContains1(oy_func, self.gemv)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -1137,10 +1123,9 @@ class BaseGemv(object):
oy = alpha * T.dot(a,x)+beta*y[::2]
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
self.assertFunctionContains1(oy_func, self.gemv)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
......@@ -1154,21 +1139,24 @@ class BaseGemv(object):
oy = alpha * T.dot(a.T,x)+beta*y[::2]
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_func = theano.function([], oy, mode=self.mode)
topo = oy_func.maker.env.toposort()
assert sum([isinstance(node.op, theano.tensor.blas.Gemv) for node in topo])==1
self.assertFunctionContains1(oy_func, self.gemv)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
class TestSgemv(TestCase, BaseGemv):
class TestSgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
dtype = float32
gemv = theano.tensor.blas.gemv_no_inplace
gemv_inplace = theano.tensor.blas.gemv_inplace
class TestDgemv(TestCase, BaseGemv):
class TestDgemv(TestCase, BaseGemv, unittest_tools.TestOptimizationMixin):
dtype = float64
gemv = theano.tensor.blas.gemv_no_inplace
gemv_inplace = theano.tensor.blas.gemv_inplace
#The optimization to put Gemv don't work for complex type for now.
# See ticket 653.
......@@ -1373,27 +1361,7 @@ class TestGer_make_thunk(TestCase):
def test_c128_1_9(self): return self.given_dtype('complex128', 1, 9)
# TODO: Refactor and add to this base class as we refactor test code.
class TestOptimizationMixin(object):
def assertFunctionContains(self, f, op, min=1, max=sys.maxint):
toposort = f.maker.env.toposort()
matches = [node for node in toposort if node.op == op]
assert (min <= len(matches) <= max), toposort
def assertFunctionContains0(self, f, op):
return self.assertFunctionContains(f, op, min=0, max=0)
def assertFunctionContains1(self, f, op):
return self.assertFunctionContains(f, op, min=1, max=1)
def assertFunctionContainsN(self, f, op, N):
return self.assertFunctionContains(f, op, min=N, max=N)
def SkipTest(self):
raise Exception('how do I skip this test properly?')
class TestGer_local_gemm_to_ger(TestCase, TestOptimizationMixin):
class TestGer_local_gemm_to_ger(TestCase, unittest_tools.TestOptimizationMixin):
def setUp(self):
self.mode = theano.compile.get_default_mode().including('fast_run')
......
......@@ -5,11 +5,13 @@ import numpy
import theano.tensor as T
from theano.configparser import config, AddConfigVar, StrParam
AddConfigVar('unittests.rseed',
"Seed to use for randomized unit tests. Special value 'random' means using a seed of None.",
StrParam(666),
in_c_key=False)
def fetch_seed(pseed=None):
"""
Returns the seed to use for running the unit tests.
......@@ -38,6 +40,7 @@ def fetch_seed(pseed=None):
return seed
def seed_rng(pseed=None):
"""
Seeds numpy's random number generator with the value returned by fetch_seed.
......@@ -51,6 +54,7 @@ def seed_rng(pseed=None):
numpy.random.seed(seed)
return seed
def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
"""
Wrapper for tensor/basic.py:verify_grad
......@@ -72,3 +76,24 @@ def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
# raise
#
verify_grad.E_grad = T.verify_grad.E_grad
class TestOptimizationMixin(object):
def assertFunctionContains(self, f, op, min=1, max=sys.maxint):
toposort = f.maker.env.toposort()
matches = [node for node in toposort if node.op == op]
assert (min <= len(matches) <= max), toposort
def assertFunctionContains0(self, f, op):
return self.assertFunctionContains(f, op, min=0, max=0)
def assertFunctionContains1(self, f, op):
return self.assertFunctionContains(f, op, min=1, max=1)
def assertFunctionContainsN(self, f, op, N):
return self.assertFunctionContains(f, op, min=N, max=N)
def SkipTest(self):
raise Exception('how do I skip this test properly?')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论