提交 22c567c0 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

ported scipy gemv tests

上级 9b238d6b
import unittest import unittest
import theano import theano
import theano.tensor as T import theano.tensor as T
from theano import function, Mode from theano import function
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet.ConvTransp3D import convTransp3D from theano.tensor.nnet.ConvTransp3D import convTransp3D
from theano.tensor.nnet.ConvGrad3D import convGrad3D from theano.tensor.nnet.ConvGrad3D import convGrad3D
...@@ -13,6 +13,10 @@ from theano import shared ...@@ -13,6 +13,10 @@ from theano import shared
floatX = theano.config.floatX floatX = theano.config.floatX
#TODO: each individual test method should seed rng with utt.fetch_seed()
# as it is right now, setUp does the seeding, so if you run just
# a subset of the tests they will do different things than if you
# run all of them
class DummyConv3D: class DummyConv3D:
"""A dummy version of Conv3D passed to verify_grad """A dummy version of Conv3D passed to verify_grad
......
...@@ -5,6 +5,9 @@ import theano.tensor as T ...@@ -5,6 +5,9 @@ import theano.tensor as T
from theano.printing import pp from theano.printing import pp
import numpy, theano import numpy, theano
from numpy import (arange, array, common_type, complex64, complex128, float32,
float64, newaxis, shape, transpose, zeros)
from numpy.testing import assert_, assert_array_almost_equal
#from numpy.testing import dec #from numpy.testing import dec
#from numpy.testing.noseclasses import KnownFailureTest #from numpy.testing.noseclasses import KnownFailureTest
...@@ -751,3 +754,162 @@ def test_gemv2(): ...@@ -751,3 +754,162 @@ def test_gemv2():
assert sum(isinstance(node.op, Gemv) for node in topo)==1 assert sum(isinstance(node.op, Gemv) for node in topo)==1
if config.mode != 'FAST_COMPILE': if config.mode != 'FAST_COMPILE':
assert topo[-1].op.inplace==True assert topo[-1].op.inplace==True
# The following gemv tests were added in March 2011 by Ian Goodfellow
# and are based on the gemv tests from scipy
# http://projects.scipy.org/scipy/browser/trunk/scipy/linalg/tests/test_fblas.py?rev=6803
# NOTE: At the time these tests were written, theano did not have a
# conjugate function. If such a thing is ever added, the tests involving
# conjugate should be ported over as well.
def matrixmultiply(a, b):
if len(b.shape) == 1:
b_is_vector = True
b = b[:,newaxis]
else:
b_is_vector = False
assert_(a.shape[1] == b.shape[0])
c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
for i in xrange(a.shape[0]):
for j in xrange(b.shape[1]):
s = 0
for k in xrange(a.shape[1]):
s += a[i,k] * b[k, j]
c[i,j] = s
if b_is_vector:
c = c.reshape((a.shape[0],))
return c
class BaseGemv(object):
def get_data(self,x_stride=1,y_stride=1):
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
mult = array(1, dtype = self.dtype)
if self.dtype in [complex64,complex128]:
mult = array(1+1j, dtype = self.dtype)
alpha = array(1., dtype = self.dtype) * mult
beta = array(1., dtype = self.dtype) * mult
a = rng.randn(3,3).astype(self.dtype) * mult
x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult
y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult
return alpha,beta,a,x,y
def test_simple(self):
alpha, beta, a, x, y = [ shared(value) for value in self.get_data() ]
desired_oy = alpha.get_value() * matrixmultiply(a.get_value(),x.get_value()) + beta.get_value() * y.get_value()
oy = alpha * T.dot(a,x) + beta * y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_val = oy_func()
assert_array_almost_equal(desired_oy, oy_val)
def test_default_beta_y(self):
vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs
a = shared(a_v)
x = shared(x_v)
desired_oy = matrixmultiply(a_v, x_v)
oy = T.dot(a,x)
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
def test_simple_transpose(self):
vs = self.get_data()
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ shared(v) for v in vs ]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v
oy = alpha * T.dot(a.T,x)+beta*y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
def test_x_stride(self):
vs = self.get_data(x_stride = 2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ shared(v) for v in vs ]
desired_oy = alpha_v * matrixmultiply(a_v,x_v[::2])+beta_v*y_v
oy = alpha * T.dot(a,x[::2])+beta*y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
def test_x_stride_transpose(self):
vs = self.get_data(x_stride = 2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ shared(v) for v in vs ]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v[::2])+beta_v*y_v
oy = alpha * T.dot(a.T,x[::2])+beta*y
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
def test_y_stride(self):
vs = self.get_data(y_stride = 2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ shared(v) for v in vs ]
desired_oy = alpha_v * matrixmultiply(a_v,x_v)+beta_v*y_v[::2]
oy = alpha * T.dot(a,x)+beta*y[::2]
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
def test_y_stride_transpose(self):
vs = self.get_data(y_stride = 2)
alpha_v, beta_v, a_v, x_v, y_v = vs
alpha, beta, a, x, y = [ shared(v) for v in vs ]
desired_oy = alpha_v * matrixmultiply(transpose(a_v),x_v)+beta_v*y_v[::2]
oy = alpha * T.dot(a.T,x)+beta*y[::2]
oy_func = theano.function([], oy, mode = mode_blas_opt)
oy_v = oy_func()
assert_array_almost_equal(desired_oy, oy_v)
try:
class TestSgmev(TestCase, BaseGemv):
dtype = float32
except AttributeError:
class TestSgmev: pass
class TestDgemv(TestCase, BaseGemv):
dtype = float64
try:
class TestCgemv(TestCase, BaseGemv):
dtype = complex64
except AttributeError:
class TestCgemv: pass
class TestZgemv(TestCase, BaseGemv):
dtype = complex128
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论