提交 4855e5b1 authored 作者: Frederic's avatar Frederic

Make some ger tests run on the GPU.

上级 659a7465
......@@ -17,8 +17,9 @@ import theano.sandbox.cuda as tcn
from theano.tensor.signal.downsample import DownsampleFactorMax, DownsampleFactorMaxGrad
import theano.compile.mode
from theano.tensor.tests.test_blas import BaseGemv
from theano.tensor.tests.test_blas import BaseGemv, TestGer_local_gemm_to_ger
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace, gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace, gpu_ger_no_inplace
if theano.config.mode=='FAST_COMPILE':
......@@ -258,3 +259,24 @@ class TestGpuGemv(TestCase, BaseGemv,
# the gemv inplace.
gemv = gpu_gemv_inplace
gemv_inplace = gpu_gemv_inplace
class TestGpuGer(TestGer_local_gemm_to_ger):
    """Run the ``TestGer_local_gemm_to_ger`` test methods against the GPU ops.

    Only the fixture is overridden: the inherited tests compare the compiled
    function against ``self.ger``, ``self.ger_destructive`` and ``self.gemm``,
    which are set here to the CUDA implementations.
    """

    def setUp(self):
        """Build a GPU-enabled mode, float32 inputs and the expected GPU ops."""
        self.mode = theano.compile.get_default_mode().including(
            'fast_run', 'gpu')
        self.mode = self.mode.excluding('c_blas')
        dtype = self.dtype = 'float32'  # optimization isn't dtype-dependent
        self.A = tensor.tensor(dtype=dtype, broadcastable=(False, False))
        self.a = tensor.tensor(dtype=dtype, broadcastable=())
        self.x = tensor.tensor(dtype=dtype, broadcastable=(False,))
        self.y = tensor.tensor(dtype=dtype, broadcastable=(False,))
        # Remember the flag so the inherited tearDown can restore it.
        self.origval = theano.tensor.blas_scipy.optimizations_enabled
        theano.tensor.blas_scipy.optimizations_enabled = False
        # Data on the GPU makes the op always inplace, so every expected op
        # is the inplace variant (the previous no_inplace assignments were
        # immediately overwritten and have been dropped).
        self.ger = gpu_ger_inplace
        self.ger_destructive = gpu_ger_inplace
        self.gemm = tcn.blas.gpu_gemm_inplace
......@@ -1403,6 +1403,9 @@ class TestGer_local_gemm_to_ger(TestCase, unittest_tools.TestOptimizationMixin):
self.y = T.tensor(dtype=dtype, broadcastable=(False,))
self.origval = theano.tensor.blas_scipy.optimizations_enabled
theano.tensor.blas_scipy.optimizations_enabled = False
self.ger = ger
self.ger_destructive = ger_destructive
self.gemm = gemm_no_inplace
def tearDown(self):
theano.tensor.blas_scipy.optimizations_enabled = self.origval
......@@ -1431,19 +1434,23 @@ class TestGer_local_gemm_to_ger(TestCase, unittest_tools.TestOptimizationMixin):
def test_outer(self):
f = self.function([self.x, self.y], T.outer(self.x, self.y))
self.assertFunctionContains(f, ger_destructive)
self.assertFunctionContains(f, self.ger_destructive)
def test_A_plus_outer(self):
f = self.function([self.A, self.x, self.y],
self.A + T.outer(self.x, self.y))
self.assertFunctionContains(f, ger)
self.assertFunctionContains(f, self.ger)
def test_A_plus_scaled_outer(self):
f = self.function([self.A, self.x, self.y],
self.A + 0.1 * T.outer(self.x, self.y))
self.assertFunctionContains(f, ger)
self.assertFunctionContains(f, self.ger)
def test_scaled_A_plus_scaled_outer(self):
f = self.function([self.A, self.x, self.y],
0.2 * self.A + 0.1 * T.outer(self.x, self.y))
self.assertFunctionContains(f, gemm_no_inplace)
numpy.asarray(0.2, self.dtype) * self.A +
numpy.asarray(0.1, self.dtype) * T.outer(
self.x, self.y))
# Why gemm? This makes the graph simpler, but did we test that it
# makes it faster?
self.assertFunctionContains(f, self.gemm)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论