提交 d9237bf3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add tests for infer_shape in blocksparse.

上级 83d75317
......@@ -7,7 +7,8 @@ import theano.tests.unittest_tools as utt
import theano.tensor.nnet.tests.test_blocksparse
import theano.sandbox.cuda as cuda_ndarray
from theano.sandbox.cuda.blocksparse import (GpuSparseBlockOuter,
from theano.sandbox.cuda.blocksparse import (GpuSparseBlockGemv,
GpuSparseBlockOuter,
gpu_sparse_block_gemv,
gpu_sparse_block_outer)
from theano.sandbox.cuda.var import float32_shared_constructor
......@@ -28,6 +29,8 @@ class BlockSparse_Gemv_and_Outer(
self.mode = mode_with_gpu.excluding('constant_folding')
self.gemv_op = gpu_sparse_block_gemv
self.outer_op = gpu_sparse_block_outer
self.gemv_class = GpuSparseBlockGemv
self.outer_class = GpuSparseBlockOuter
# This test is temporarily disabled since we disabled the output_merge
# and alpha_merge optimizations for blocksparse due to brokeness.
......
"""
Tests for block sparse dot
"""
import unittest
import numpy
from numpy.random import randn
......@@ -10,15 +8,12 @@ import theano
from theano import tensor
import theano.tests.unittest_tools as utt
from theano.tensor.nnet.blocksparse import sparse_block_dot, \
sparse_block_gemv, sparse_block_outer
class BlockSparse_Gemv_and_Outer(unittest.TestCase):
from theano.tensor.nnet.blocksparse import (
sparse_block_dot, sparse_block_gemv, sparse_block_outer,
SparseBlockGemv, SparseBlockOuter)
def runTest(self):
pass
class BlockSparse_Gemv_and_Outer(utt.InferShapeTester):
def setUp(self):
utt.seed_rng()
mode = None
......@@ -29,6 +24,8 @@ class BlockSparse_Gemv_and_Outer(unittest.TestCase):
)
self.gemv_op = sparse_block_gemv
self.outer_op = sparse_block_outer
self.gemv_class = SparseBlockGemv
self.outer_class = SparseBlockOuter
@staticmethod
def gemv_data():
......@@ -280,3 +277,40 @@ class BlockSparse_Gemv_and_Outer(unittest.TestCase):
o_val, x_val, y_val, xIdx_val, yIdx_val)
utt.assert_allclose(ref_out, th_out)
def test_dot_infershape(self):
b = tensor.fmatrix()
W = tensor.ftensor4()
h = tensor.ftensor3()
iIdx = tensor.imatrix()
oIdx = tensor.imatrix()
self._compile_and_check([W, h, iIdx, b, oIdx],
[sparse_block_dot(W, h, iIdx, b, oIdx)],
self.gemv_data(),
self.gemv_class)
def test_gemv_infershape(self):
b = tensor.fmatrix()
W = tensor.ftensor4()
h = tensor.ftensor3()
iIdx = tensor.imatrix()
oIdx = tensor.imatrix()
self._compile_and_check(
[W, h, iIdx, b, oIdx],
[self.gemv_op(b.take(oIdx, axis=0), W, h, iIdx, oIdx)],
self.gemv_data(),
self.gemv_class)
def test_outer_infershape(self):
o = tensor.ftensor4()
x = tensor.ftensor3()
y = tensor.ftensor3()
xIdx = tensor.imatrix()
yIdx = tensor.imatrix()
self._compile_and_check([o, x, y, xIdx, yIdx],
[self.outer_op(o, x, y, xIdx, yIdx)],
self.outer_data(),
self.outer_class)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论