提交 e8ee0256 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

More corrections.

上级 63b5ea7f
...@@ -76,7 +76,7 @@ class HStack(gof.op.Op): ...@@ -76,7 +76,7 @@ class HStack(gof.op.Op):
"""Stack sparse matrices horizontally (column wise). """Stack sparse matrices horizontally (column wise).
:param blocks: Sequence of sparse array of compatible shape. :param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format. Defaul :param format: String representing the output format. Default
is csc. is csc.
:param dtype: Output dtype. Must be specified. :param dtype: Output dtype. Must be specified.
...@@ -87,7 +87,7 @@ class HStack(gof.op.Op): ...@@ -87,7 +87,7 @@ class HStack(gof.op.Op):
- The grad implemented is regular, i.e. not structured. - The grad implemented is regular, i.e. not structured.
""" """
def __init__(self, dtype, format=None): def __init__(self, format=None, dtype=None):
if format is None: if format is None:
self.format = 'csc' self.format = 'csc'
else: else:
...@@ -123,19 +123,24 @@ class HStack(gof.op.Op): ...@@ -123,19 +123,24 @@ class HStack(gof.op.Op):
is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes) is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes)
for i in range(len(inputs))] for i in range(len(inputs))]
if all(is_continuous): if _is_sparse_variable(gz):
if _is_sparse_variable(gz): gz = sparse.DenseFromSparse()(gz)
gz = sparse.DenseFromSparse()(gz)
split = tensor.Split(len(inputs))(gz, 1,
split = tensor.Split(len(inputs))(gz, 1, tensor.stack(
tensor.stack( *[x.shape[1]
*[x.shape[1] for x in inputs]))
for x in inputs])) if not isinstance(split, list):
if not isinstance(split, list): split = [split]
split = [split]
return [sparse.SparseFromDense(self.format)(s) for s in split] derivative = [sparse.SparseFromDense(self.format)(s) for s in split]
else:
return [None] * len(inputs) def choose(continuous, derivative):
if continuous:
return derivative
else:
return None
return [choose(c, d) for c, d in zip(is_continuous, derivative)]
def infer_shape(self, node, ins_shapes): def infer_shape(self, node, ins_shapes):
def _get(l): def _get(l):
...@@ -153,9 +158,9 @@ def hstack(blocks, format=None, dtype=None): ...@@ -153,9 +158,9 @@ def hstack(blocks, format=None, dtype=None):
This wrap the method hstack from scipy. This wrap the method hstack from scipy.
:param blocks: List of sparse array of compatible shape. :param blocks: List of sparse array of compatible shape.
:param format: String representing the output format. Defaul :param format: String representing the output format. Default
is csc. is csc.
:param dtype: Output dtype. Must be specified. :param dtype: Output dtype.
:return: The concatenation of the sparse array column wise. :return: The concatenation of the sparse array column wise.
...@@ -164,16 +169,17 @@ def hstack(blocks, format=None, dtype=None): ...@@ -164,16 +169,17 @@ def hstack(blocks, format=None, dtype=None):
- The grad implemented is regular, i.e. not structured. - The grad implemented is regular, i.e. not structured.
""" """
blocks = [as_sparse_variable(i) for i in blocks]
if dtype is None: if dtype is None:
raise ValueError('The output dtype must be specified.') dtype = theano.scalar.upcast([i.dtype for i in blocks])
return HStack(dtype, format=format)(*blocks) return HStack(format=format, dtype=dtype)(*blocks)
class VStack(HStack): class VStack(HStack):
"""Stack sparse matrices vertically (row wise). """Stack sparse matrices vertically (row wise).
:param blocks: Sequence of sparse array of compatible shape. :param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format. Defaul :param format: String representing the output format. Default
is csc. is csc.
:param dtype: Output dtype. Must be specified. :param dtype: Output dtype. Must be specified.
...@@ -194,19 +200,24 @@ class VStack(HStack): ...@@ -194,19 +200,24 @@ class VStack(HStack):
is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes) is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes)
for i in range(len(inputs))] for i in range(len(inputs))]
if all(is_continuous): if _is_sparse_variable(gz):
if _is_sparse_variable(gz): gz = sparse.DenseFromSparse()(gz)
gz = sparse.DenseFromSparse()(gz)
split = tensor.Split(len(inputs))(gz, 0,
split = tensor.Split(len(inputs))(gz, 0, tensor.stack(
tensor.stack( *[x.shape[0]
*[x.shape[0] for x in inputs]))
for x in inputs])) if not isinstance(split, list):
if not isinstance(split, list): split = [split]
split = [split]
return [sparse.SparseFromDense(self.format)(s) for s in split] derivative = [sparse.SparseFromDense(self.format)(s) for s in split]
else:
return [None] * len(inputs) def choose(continuous, derivative):
if continuous:
return derivative
else:
return None
return [choose(c, d) for c, d in zip(is_continuous, derivative)]
def infer_shape(self, node, ins_shapes): def infer_shape(self, node, ins_shapes):
def _get(l): def _get(l):
...@@ -221,7 +232,7 @@ def vstack(blocks, format=None, dtype=None): ...@@ -221,7 +232,7 @@ def vstack(blocks, format=None, dtype=None):
This wrap the method vstack from scipy. This wrap the method vstack from scipy.
:param blocks: List of sparse array of compatible shape. :param blocks: List of sparse array of compatible shape.
:param format: String representing the output format. Defaul :param format: String representing the output format. Default
is csc. is csc.
:param dtype: Output dtype. :param dtype: Output dtype.
...@@ -232,8 +243,9 @@ def vstack(blocks, format=None, dtype=None): ...@@ -232,8 +243,9 @@ def vstack(blocks, format=None, dtype=None):
- The grad implemented is regular, i.e. not structured. - The grad implemented is regular, i.e. not structured.
""" """
blocks = [as_sparse_variable(i) for i in blocks]
if dtype is None: if dtype is None:
raise ValueError('The output dtype must be specified.') dtype = theano.scalar.upcast([i.dtype for i in blocks])
return VStack(format=format, dtype=dtype)(*blocks) return VStack(format=format, dtype=dtype)(*blocks)
......
...@@ -163,8 +163,7 @@ class _HVStackTester(utt.InferShapeTester): ...@@ -163,8 +163,7 @@ class _HVStackTester(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
self._compile_and_check(self.x[format], self._compile_and_check(self.x[format],
[self.op_class(theano.config.floatX) [self.op_class(dtype='float64')(*self.x[format])],
(*self.x[format])],
self.mat[format], self.mat[format],
self.op_class) self.op_class)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论