提交 4a12110c authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1691 from SuperElectric/sparse_stack_fix

Bugfix in sparse.hstack(), sparse.vstack()
...@@ -2302,7 +2302,7 @@ def hstack(blocks, format=None, dtype=None): ...@@ -2302,7 +2302,7 @@ def hstack(blocks, format=None, dtype=None):
blocks = [as_sparse_variable(i) for i in blocks] blocks = [as_sparse_variable(i) for i in blocks]
if dtype is None: if dtype is None:
dtype = theano.scalar.upcast([i.dtype for i in blocks]) dtype = theano.scalar.upcast(*[i.dtype for i in blocks])
return HStack(format=format, dtype=dtype)(*blocks) return HStack(format=format, dtype=dtype)(*blocks)
...@@ -2378,7 +2378,7 @@ def vstack(blocks, format=None, dtype=None): ...@@ -2378,7 +2378,7 @@ def vstack(blocks, format=None, dtype=None):
blocks = [as_sparse_variable(i) for i in blocks] blocks = [as_sparse_variable(i) for i in blocks]
if dtype is None: if dtype is None:
dtype = theano.scalar.upcast([i.dtype for i in blocks]) dtype = theano.scalar.upcast(*[i.dtype for i in blocks])
return VStack(format=format, dtype=dtype)(*blocks) return VStack(format=format, dtype=dtype)(*blocks)
......
...@@ -1950,7 +1950,7 @@ class Test_getitem(unittest.TestCase): ...@@ -1950,7 +1950,7 @@ class Test_getitem(unittest.TestCase):
assert r1.shape == t1.shape assert r1.shape == t1.shape
assert numpy.all(t1.toarray() == r1.toarray()) assert numpy.all(t1.toarray() == r1.toarray())
"""" """
Important: based on a discussion with both Fred and James Important: based on a discussion with both Fred and James
The following indexing methods is not supported because the rval The following indexing methods is not supported because the rval
would be a sparse matrix rather than a sparse vector, which is a would be a sparse matrix rather than a sparse vector, which is a
...@@ -2463,6 +2463,35 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2463,6 +2463,35 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
return Tester return Tester
def test_hstack_vstack():
"""
Tests sparse.hstack and sparse.vstack (as opposed to the HStack and VStack
classes that they wrap).
"""
def make_block(dtype):
return theano.sparse.csr_matrix(name="%s block" % dtype,
dtype=dtype)
def get_expected_dtype(blocks, to_dtype):
if to_dtype is None:
block_dtypes = tuple(b.dtype for b in blocks)
return theano.scalar.upcast(*block_dtypes)
else:
return to_dtype
# a deliberately weird mix of dtypes to stack
dtypes = ('complex128', theano.config.floatX)
blocks = [make_block(dtype) for dtype in dtypes]
for stack_dimension, stack_function in enumerate((theano.sparse.vstack,
theano.sparse.hstack)):
for to_dtype in (None, ) + dtypes:
stacked_blocks = stack_function(blocks, dtype=to_dtype)
expected_dtype = get_expected_dtype(blocks, to_dtype)
assert stacked_blocks.dtype == expected_dtype
def structure_function(f, index=0): def structure_function(f, index=0):
"""Decorator to structure a function wich """Decorator to structure a function wich
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论