提交 ed03732c authored 作者: James Bergstra's avatar James Bergstra

Tests for ScipyGer

上级 a10df29e
......@@ -23,9 +23,10 @@ try:
#numpy.dtype('complex64'):scipy.linalg.blas.fblas.cger,
#numpy.dtype('complex128'):scipy.linalg.blas.fblas.zger,
}
optimizations_enabled = True
except ImportError, e:
have_fblas = False
optimizations_enabled = False
class ScipyGer(Ger):
......@@ -61,11 +62,13 @@ class ScipyGer(Ger):
@local_optimizer([ger, ger_destructive])
def use_scipy_ger(node):
    """Rewrite a plain `ger` node into a non-destructive `ScipyGer` apply.

    Returns None (no replacement) when scipy-backed optimizations are
    disabled or when the node's op is not `ger`.
    """
    if not optimizations_enabled:
        return None
    if node.op != ger:
        return None
    return [ScipyGer(False)(*node.inputs)]
@local_optimizer([ScipyGer(False)])
def make_ger_destructive(node):
    """Swap a non-destructive ScipyGer apply for its in-place variant.

    Returns None when scipy-backed optimizations are disabled or the
    node is not a `ScipyGer(False)` apply.
    """
    if not optimizations_enabled:
        return None
    if node.op != ScipyGer(False):
        return None
    return [ScipyGer(True)(*node.inputs)]
......
......@@ -25,6 +25,7 @@ from theano import Param, shared, config
from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
compile, inplace)
#, constant, eval_outputs)
import theano.tensor.blas_scipy
if config.mode == 'FAST_COMPILE':
mode_not_fast_compile = 'FAST_RUN'
......@@ -1314,6 +1315,9 @@ class TestOptimizationMixin(object):
def assertFunctionContainsN(self, f, op, N):
    """Assert that compiled function `f` contains exactly `N` applies of `op`."""
    return self.assertFunctionContains(f, op, max=N, min=N)
def SkipTest(self):
    """Skip the current test in a way test runners understand.

    Raises
    ------
    unittest.SkipTest
        Recognized by unittest (2.7+) and nose, so the test is reported
        as skipped rather than as an error. This replaces the previous
        placeholder that raised a bare Exception asking "how do I skip
        this test properly?". SkipTest subclasses Exception, so any
        caller catching Exception still works.
    """
    # Local import: this method's own name would otherwise shadow a
    # module-level `SkipTest` import.
    from unittest import SkipTest as _SkipTest
    raise _SkipTest('skipping test: required optional dependency unavailable')
class TestGer_local_gemm_to_ger(TestCase, TestOptimizationMixin):
def setUp(self):
......@@ -1323,6 +1327,11 @@ class TestGer_local_gemm_to_ger(TestCase, TestOptimizationMixin):
self.a = T.tensor(dtype=dtype, broadcastable=())
self.x = T.tensor(dtype=dtype, broadcastable=(False,))
self.y = T.tensor(dtype=dtype, broadcastable=(False,))
self.origval = theano.tensor.blas_scipy.optimizations_enabled
theano.tensor.blas_scipy.optimizations_enabled = False
def tearDown(self):
    """Restore the scipy-optimization flag stashed away by setUp."""
    blas_scipy_module = theano.tensor.blas_scipy
    blas_scipy_module.optimizations_enabled = self.origval
def function(self, inputs, outputs):
    """Compile `outputs` from `inputs` under this test case's mode."""
    compiled = theano.function(inputs, outputs, self.mode)
    return compiled
......
import sys
import numpy
import theano
import theano.tensor as tensor
from theano.tensor.blas_scipy import ScipyGer
from test_blas import TestCase, TestOptimizationMixin, gemm_no_inplace
class TestScipyGer(TestCase, TestOptimizationMixin):
    """Check that fast_run rewrites outer-product graphs to ScipyGer.

    Each test compiles a small symbolic expression and asserts which op
    (a ScipyGer variant or gemm) appears in the optimized graph; running
    the compiled function afterwards lets DebugMode check numerical
    correctness.
    """

    def setUp(self):
        self.mode = theano.compile.get_default_mode().including('fast_run')
        # optimization isn't dtype-dependent
        dtype = self.dtype = 'float64'
        make_var = tensor.tensor
        self.A = make_var(dtype=dtype, broadcastable=(False, False))
        self.a = make_var(dtype=dtype, broadcastable=())
        self.x = make_var(dtype=dtype, broadcastable=(False,))
        self.y = make_var(dtype=dtype, broadcastable=(False,))
        self.Aval = numpy.ones((2, 3), dtype=dtype)
        self.xval = numpy.asarray([1, 2], dtype=dtype)
        self.yval = numpy.asarray([1.5, 2.7, 3.9], dtype=dtype)
        # Without scipy's fblas ger there is nothing to test here.
        if not theano.tensor.blas_scipy.optimizations_enabled:
            self.SkipTest()

    def function(self, inputs, outputs):
        """Compile `outputs` from `inputs` under this test's mode."""
        return theano.function(inputs, outputs, self.mode)

    def run_f(self, f):
        """Evaluate a compiled function on the stock A/x/y values."""
        f(self.Aval, self.xval, self.yval)

    def b(self, bval):
        """Wrap a scalar as a constant tensor variable of self.dtype."""
        return tensor.as_tensor_variable(numpy.asarray(bval, dtype=self.dtype))

    def test_outer(self):
        expr = tensor.outer(self.x, self.y)
        f = self.function([self.x, self.y], expr)
        self.assertFunctionContains(f, ScipyGer(destructive=True))

    def test_A_plus_outer(self):
        expr = self.A + tensor.outer(self.x, self.y)
        f = self.function([self.A, self.x, self.y], expr)
        self.assertFunctionContains(f, ScipyGer(destructive=False))
        self.run_f(f)  # DebugMode tests correctness

    def test_A_plus_scaled_outer(self):
        expr = self.A + 0.1 * tensor.outer(self.x, self.y)
        f = self.function([self.A, self.x, self.y], expr)
        self.assertFunctionContains(f, ScipyGer(destructive=False))
        self.run_f(f)  # DebugMode tests correctness

    def test_scaled_A_plus_scaled_outer(self):
        expr = 0.2 * self.A + 0.1 * tensor.outer(self.x, self.y)
        f = self.function([self.A, self.x, self.y], expr)
        self.assertFunctionContains(f, gemm_no_inplace)
        self.run_f(f)  # DebugMode tests correctness
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论