提交 4b1ece1b authored 作者: Matthew Koichi Grimes's avatar Matthew Koichi Grimes

added unit tests for theano.sparse.hstack(), vstack()

上级 e2bd5b94
......@@ -1950,7 +1950,7 @@ class Test_getitem(unittest.TestCase):
assert r1.shape == t1.shape
assert numpy.all(t1.toarray() == r1.toarray())
""""
"""
Important: based on a discussion with both Fred and James
The following indexing methods is not supported because the rval
would be a sparse matrix rather than a sparse vector, which is a
......@@ -2463,6 +2463,35 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
return Tester
def test_hstack_vstack():
"""
Tests sparse.hstack and sparse.vstack (as opposed to the HStack and VStack
classes that they wrap).
"""
def make_block(dtype):
return theano.sparse.csr_matrix(name="%s block" % dtype,
dtype=dtype)
def get_expected_dtype(blocks, to_dtype):
if to_dtype is None:
block_dtypes = tuple(b.dtype for b in blocks)
return theano.scalar.upcast(*block_dtypes)
else:
return to_dtype
# a deliberately weird mix of dtypes to stack
dtypes = ('complex128', theano.config.floatX)
blocks = [make_block(dtype) for dtype in dtypes]
for stack_dimension, stack_function in enumerate((theano.sparse.vstack,
theano.sparse.hstack)):
for to_dtype in (None, ) + dtypes:
stacked_blocks = stack_function(blocks, dtype=to_dtype)
expected_dtype = get_expected_dtype(blocks, to_dtype)
assert stacked_blocks.dtype == expected_dtype
def structure_function(f, index=0):
"""Decorator to structure a function wich
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论