提交 e8ee0256 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

More corrections.

上级 63b5ea7f
......@@ -76,7 +76,7 @@ class HStack(gof.op.Op):
"""Stack sparse matrices horizontally (column wise).
:param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format. Defaul
:param format: String representing the output format. Default
is csc.
:param dtype: Output dtype. Must be specified.
......@@ -87,7 +87,7 @@ class HStack(gof.op.Op):
- The grad implemented is regular, i.e. not structured.
"""
def __init__(self, dtype, format=None):
def __init__(self, format=None, dtype=None):
if format is None:
self.format = 'csc'
else:
......@@ -123,19 +123,24 @@ class HStack(gof.op.Op):
is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes)
for i in range(len(inputs))]
if all(is_continuous):
if _is_sparse_variable(gz):
gz = sparse.DenseFromSparse()(gz)
split = tensor.Split(len(inputs))(gz, 1,
tensor.stack(
*[x.shape[1]
for x in inputs]))
if not isinstance(split, list):
split = [split]
return [sparse.SparseFromDense(self.format)(s) for s in split]
else:
return [None] * len(inputs)
if _is_sparse_variable(gz):
gz = sparse.DenseFromSparse()(gz)
split = tensor.Split(len(inputs))(gz, 1,
tensor.stack(
*[x.shape[1]
for x in inputs]))
if not isinstance(split, list):
split = [split]
derivative = [sparse.SparseFromDense(self.format)(s) for s in split]
def choose(continuous, derivative):
if continuous:
return derivative
else:
return None
return [choose(c, d) for c, d in zip(is_continuous, derivative)]
def infer_shape(self, node, ins_shapes):
def _get(l):
......@@ -153,9 +158,9 @@ def hstack(blocks, format=None, dtype=None):
This wrap the method hstack from scipy.
:param blocks: List of sparse array of compatible shape.
:param format: String representing the output format. Defaul
:param format: String representing the output format. Default
is csc.
:param dtype: Output dtype. Must be specified.
:param dtype: Output dtype.
:return: The concatenation of the sparse array column wise.
......@@ -164,16 +169,17 @@ def hstack(blocks, format=None, dtype=None):
- The grad implemented is regular, i.e. not structured.
"""
blocks = [as_sparse_variable(i) for i in blocks]
if dtype is None:
raise ValueError('The output dtype must be specified.')
return HStack(dtype, format=format)(*blocks)
dtype = theano.scalar.upcast([i.dtype for i in blocks])
return HStack(format=format, dtype=dtype)(*blocks)
class VStack(HStack):
"""Stack sparse matrices vertically (row wise).
:param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format. Defaul
:param format: String representing the output format. Default
is csc.
:param dtype: Output dtype. Must be specified.
......@@ -194,19 +200,24 @@ class VStack(HStack):
is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes)
for i in range(len(inputs))]
if all(is_continuous):
if _is_sparse_variable(gz):
gz = sparse.DenseFromSparse()(gz)
split = tensor.Split(len(inputs))(gz, 0,
tensor.stack(
*[x.shape[0]
for x in inputs]))
if not isinstance(split, list):
split = [split]
return [sparse.SparseFromDense(self.format)(s) for s in split]
else:
return [None] * len(inputs)
if _is_sparse_variable(gz):
gz = sparse.DenseFromSparse()(gz)
split = tensor.Split(len(inputs))(gz, 0,
tensor.stack(
*[x.shape[0]
for x in inputs]))
if not isinstance(split, list):
split = [split]
derivative = [sparse.SparseFromDense(self.format)(s) for s in split]
def choose(continuous, derivative):
if continuous:
return derivative
else:
return None
return [choose(c, d) for c, d in zip(is_continuous, derivative)]
def infer_shape(self, node, ins_shapes):
def _get(l):
......@@ -221,7 +232,7 @@ def vstack(blocks, format=None, dtype=None):
This wrap the method vstack from scipy.
:param blocks: List of sparse array of compatible shape.
:param format: String representing the output format. Defaul
:param format: String representing the output format. Default
is csc.
:param dtype: Output dtype.
......@@ -232,8 +243,9 @@ def vstack(blocks, format=None, dtype=None):
- The grad implemented is regular, i.e. not structured.
"""
blocks = [as_sparse_variable(i) for i in blocks]
if dtype is None:
raise ValueError('The output dtype must be specified.')
dtype = theano.scalar.upcast([i.dtype for i in blocks])
return VStack(format=format, dtype=dtype)(*blocks)
......
......@@ -163,8 +163,7 @@ class _HVStackTester(utt.InferShapeTester):
def test_infer_shape(self):
for format in sparse.sparse_formats:
self._compile_and_check(self.x[format],
[self.op_class(theano.config.floatX)
(*self.x[format])],
[self.op_class(dtype='float64')(*self.x[format])],
self.mat[format],
self.op_class)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论