提交 c42d79a6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some problems and add some tests.

上级 f8b8c31a
...@@ -271,7 +271,7 @@ def local_inplace_gpuager(node): ...@@ -271,7 +271,7 @@ def local_inplace_gpuager(node):
return [gpuger_inplace(*node.inputs)] return [gpuger_inplace(*node.inputs)]
gpuablas_opt_inplace = in2out(LocalOptGroup( gpuablas_opt_inplace = in2out(LocalOptGroup(
local_inplace_gpuagemv, local_inplace_gpuagemm), local_inplace_gpuagemv, local_inplace_gpuagemm, local_inplace_gpuager),
name='gpuablas_opt_inplace') name='gpuablas_opt_inplace')
optdb.register('InplaceGpuaBlasOpt', optdb.register('InplaceGpuaBlasOpt',
gpuablas_opt_inplace, gpuablas_opt_inplace,
......
...@@ -314,7 +314,7 @@ def local_gpua_gemm(node): ...@@ -314,7 +314,7 @@ def local_gpua_gemm(node):
@register_opt() @register_opt()
@op_lifter([tensor.blas.Ger, tensor.blas_c.CGer]) @op_lifter([tensor.blas.Ger, tensor.blas_c.CGer, tensor.blas_scipy.ScipyGer])
def local_gpua_ger(node): def local_gpua_ger(node):
return GpuGer(destructive=node.op.destructive) return GpuGer(destructive=node.op.destructive)
......
from unittest import TestCase from unittest import TestCase
import theano import theano
from theano import tensor
from theano.tests import unittest_tools
from theano.tensor.blas import (gemv_inplace, gemm_inplace, ger_destructive, from theano.tensor.blas import (gemv_inplace, gemm_inplace, ger_destructive,
_dot22) _dot22)
from theano.tensor.tests.test_blas import TestGer
from theano.sandbox.gpuarray.tests.test_basic_ops import makeTester, rand from theano.sandbox.gpuarray import gpuarray_shared_constructor
from theano.sandbox.gpuarray.tests.test_basic_ops import (makeTester, rand,
mode_with_gpu)
from theano.sandbox.gpuarray.blas import (gpugemv_inplace, from theano.sandbox.gpuarray.blas import (gpugemv_inplace, gpugemv_no_inplace,
gpugemm_inplace, gpuger_inplace, gpugemm_inplace, gpugemm_no_inplace,
gpu_dot22) gpuger_inplace, gpuger_no_inplace,
GpuGer, gpu_dot22)
GpuGemvTester = makeTester('GpuGemvTester', GpuGemvTester = makeTester('GpuGemvTester',
...@@ -54,6 +60,13 @@ GpuGerTester = makeTester( ...@@ -54,6 +60,13 @@ GpuGerTester = makeTester(
) )
) )
class TestGpuGer_OpContract(TestCase, unittest_tools.T_OpContractMixin):
def setUp(self):
self.ops = [gpuger_no_inplace, gpuger_inplace]
def clone(self, op):
return GpuGer(destructive=op.destructive)
GpuDot22Tester = makeTester( GpuDot22Tester = makeTester(
'GpuGemmTester', 'GpuGemmTester',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论