提交 d9237bf3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add tests for infer_shape in blocksparse.

上级 83d75317
...@@ -7,7 +7,8 @@ import theano.tests.unittest_tools as utt ...@@ -7,7 +7,8 @@ import theano.tests.unittest_tools as utt
import theano.tensor.nnet.tests.test_blocksparse import theano.tensor.nnet.tests.test_blocksparse
import theano.sandbox.cuda as cuda_ndarray import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.blocksparse import (GpuSparseBlockOuter, from theano.sandbox.cuda.blocksparse import (GpuSparseBlockGemv,
GpuSparseBlockOuter,
gpu_sparse_block_gemv, gpu_sparse_block_gemv,
gpu_sparse_block_outer) gpu_sparse_block_outer)
from theano.sandbox.cuda.var import float32_shared_constructor from theano.sandbox.cuda.var import float32_shared_constructor
...@@ -28,6 +29,8 @@ class BlockSparse_Gemv_and_Outer( ...@@ -28,6 +29,8 @@ class BlockSparse_Gemv_and_Outer(
self.mode = mode_with_gpu.excluding('constant_folding') self.mode = mode_with_gpu.excluding('constant_folding')
self.gemv_op = gpu_sparse_block_gemv self.gemv_op = gpu_sparse_block_gemv
self.outer_op = gpu_sparse_block_outer self.outer_op = gpu_sparse_block_outer
self.gemv_class = GpuSparseBlockGemv
self.outer_class = GpuSparseBlockOuter
# This test is temporarily disabled since we disabled the output_merge # This test is temporarily disabled since we disabled the output_merge
# and alpha_merge optimizations for blocksparse due to brokeness. # and alpha_merge optimizations for blocksparse due to brokeness.
......
""" """
Tests for block sparse dot Tests for block sparse dot
""" """
import unittest
import numpy import numpy
from numpy.random import randn from numpy.random import randn
...@@ -10,15 +8,12 @@ import theano ...@@ -10,15 +8,12 @@ import theano
from theano import tensor from theano import tensor
import theano.tests.unittest_tools as utt import theano.tests.unittest_tools as utt
from theano.tensor.nnet.blocksparse import sparse_block_dot, \ from theano.tensor.nnet.blocksparse import (
sparse_block_gemv, sparse_block_outer sparse_block_dot, sparse_block_gemv, sparse_block_outer,
SparseBlockGemv, SparseBlockOuter)
class BlockSparse_Gemv_and_Outer(unittest.TestCase):
def runTest(self):
pass
class BlockSparse_Gemv_and_Outer(utt.InferShapeTester):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
mode = None mode = None
...@@ -29,6 +24,8 @@ class BlockSparse_Gemv_and_Outer(unittest.TestCase): ...@@ -29,6 +24,8 @@ class BlockSparse_Gemv_and_Outer(unittest.TestCase):
) )
self.gemv_op = sparse_block_gemv self.gemv_op = sparse_block_gemv
self.outer_op = sparse_block_outer self.outer_op = sparse_block_outer
self.gemv_class = SparseBlockGemv
self.outer_class = SparseBlockOuter
@staticmethod @staticmethod
def gemv_data(): def gemv_data():
...@@ -280,3 +277,40 @@ class BlockSparse_Gemv_and_Outer(unittest.TestCase): ...@@ -280,3 +277,40 @@ class BlockSparse_Gemv_and_Outer(unittest.TestCase):
o_val, x_val, y_val, xIdx_val, yIdx_val) o_val, x_val, y_val, xIdx_val, yIdx_val)
utt.assert_allclose(ref_out, th_out) utt.assert_allclose(ref_out, th_out)
def test_dot_infershape(self):
    """Verify that infer_shape works for the sparse_block_dot graph.

    Builds symbolic inputs of the expected ranks (weights: 4d, hidden:
    3d, bias: 2d, index matrices: int 2d) and lets
    ``_compile_and_check`` compare the inferred output shapes against
    the compiled ones, using ``self.gemv_data()`` as concrete values.
    """
    weights = tensor.ftensor4()
    hidden = tensor.ftensor3()
    bias = tensor.fmatrix()
    in_idx = tensor.imatrix()
    out_idx = tensor.imatrix()

    sym_inputs = [weights, hidden, in_idx, bias, out_idx]
    sym_outputs = [sparse_block_dot(weights, hidden, in_idx, bias, out_idx)]
    self._compile_and_check(sym_inputs, sym_outputs,
                            self.gemv_data(), self.gemv_class)
def test_gemv_infershape(self):
    """Verify that infer_shape works for the block-gemv op directly.

    Unlike ``test_dot_infershape`` this applies ``self.gemv_op``
    itself, pre-gathering the bias rows with ``take`` along axis 0 as
    the op expects; shapes are checked by ``_compile_and_check`` with
    ``self.gemv_data()`` as concrete values.
    """
    bias = tensor.fmatrix()
    weights = tensor.ftensor4()
    hidden = tensor.ftensor3()
    in_idx = tensor.imatrix()
    out_idx = tensor.imatrix()

    gathered_bias = bias.take(out_idx, axis=0)
    self._compile_and_check(
        [weights, hidden, in_idx, bias, out_idx],
        [self.gemv_op(gathered_bias, weights, hidden, in_idx, out_idx)],
        self.gemv_data(),
        self.gemv_class)
def test_outer_infershape(self):
    """Verify that infer_shape works for the block-outer op.

    Builds symbolic inputs of the expected ranks (accumulator: 4d, the
    two factors: 3d, index matrices: int 2d) and lets
    ``_compile_and_check`` compare inferred and actual output shapes,
    using ``self.outer_data()`` as concrete values.
    """
    accum = tensor.ftensor4()
    left = tensor.ftensor3()
    right = tensor.ftensor3()
    left_idx = tensor.imatrix()
    right_idx = tensor.imatrix()

    sym_inputs = [accum, left, right, left_idx, right_idx]
    sym_outputs = [self.outer_op(accum, left, right, left_idx, right_idx)]
    self._compile_and_check(sym_inputs, sym_outputs,
                            self.outer_data(), self.outer_class)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论