提交 7f15e04a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

And make the opt test work.

上级 6c77f4a6
import numpy
from numpy.random import randn
from unittest import TestCase
from nose.plugins.skip import SkipTest
import theano import theano
from theano import tensor from theano import tensor
import theano.tests.unittest_tools as utt import theano.tests.unittest_tools as utt
import numpy import theano.sandbox.cuda as cuda_ndarray
from numpy.random import randn if cuda_ndarray.cuda_available == False:
raise SkipTest('Optional package cuda disabled')
from theano.sandbox.cuda.blocksparse import (sparse_block_dot_SS, from theano.sandbox.cuda.blocksparse import (sparse_block_dot_SS,
sparse_block_gemv_ss, sparse_block_gemv_ss,
sparse_block_gemv_ss_inplace) sparse_block_gemv_ss_inplace,
sparse_block_outer_ss, sparse_block_outer_ss,
sparse_block_outer_ss_inplace) sparse_block_outer_ss_inplace)
if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu')
def blocksparse_data(): def blocksparse_data():
nInputBlock = 128 nInputBlock = 128
...@@ -109,14 +122,14 @@ class TestBlockSparseDot(TestCase, utt.TestOptimizationMixin): ...@@ -109,14 +122,14 @@ class TestBlockSparseDot(TestCase, utt.TestOptimizationMixin):
o = sparse_block_dot_SS(W, h, iIdx, b, oIdx) o = sparse_block_dot_SS(W, h, iIdx, b, oIdx)
f = theano.function([W, h, iIdx, b, oIdx], o) f = theano.function([W, h, iIdx, b, oIdx], o, mode=mode_with_gpu)
self.assertFunctionContains0(f, sparse_block_gemv_ss) self.assertFunctionContains0(f, sparse_block_gemv_ss)
self.assertFunctionContains1(f, sparse_block_gemv_ss_inplace) self.assertFunctionContains1(f, sparse_block_gemv_ss_inplace)
gW = theano.grad(o.sum(), [W]) gW = theano.grad(o.sum(), [W])
f = theano.function([W, h, iIdx, b, oIdx], gW) f = theano.function([W, h, iIdx, b, oIdx], gW, mode=mode_with_gpu)
self.assertFunctionContains0(f, sparse_block_outer_ss) self.assertFunctionContains0(f, sparse_block_outer_ss)
self.assertFunctionContains1(f, sparse_block_outer_ss_inplace) self.assertFunctionContains1(f, sparse_block_outer_ss_inplace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论