提交 63b5ea7f authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

Made corrections from review.

上级 28bdcbe3
......@@ -76,8 +76,9 @@ class HStack(gof.op.Op):
"""Stack sparse matrices horizontally (column wise).
:param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format.
:param dtype: Output dtype.
:param format: String representing the output format. Defaul
is csc.
:param dtype: Output dtype. Must be specified.
:return: The concatenation of the sparse arrays column wise.
......@@ -86,15 +87,15 @@ class HStack(gof.op.Op):
- The grad implemented is regular, i.e. not structured.
"""
def __init__(self, format=None, dtype=None):
def __init__(self, dtype, format=None):
if format is None:
self.format = 'csc'
else:
self.format = format
if dtype is None:
self.dtype = theano.config.floatX
else:
self.dtype = dtype
raise ValueError('The output dtype must be specified.')
self.dtype = dtype
def __eq__(self, other):
return (type(self) == type(other) and
......@@ -152,26 +153,29 @@ def hstack(blocks, format=None, dtype=None):
This wrap the method hstack from scipy.
:param blocks: List of sparse array of compatible shape.
:param format: String representing the output format.
:param dtype: Output dtype.
:param format: String representing the output format. Defaul
is csc.
:param dtype: Output dtype. Must be specified.
:return: The concatenation of the sparse array column wise.
:note:
- The number of line of the sparse matrix must agree.
- The grad implemented is regular, i.e. not structured.
"""
return HStack(format=format, dtype=dtype)(*blocks)
if dtype is None:
raise ValueError('The output dtype must be specified.')
return HStack(dtype, format=format)(*blocks)
class VStack(HStack):
"""Stack sparse matrices vertically (row wise).
:param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format.
:param dtype: Output dtype.
:param format: String representing the output format. Defaul
is csc.
:param dtype: Output dtype. Must be specified.
:return: The concatenation of the sparse arrays row wise.
......@@ -211,13 +215,14 @@ class VStack(HStack):
return [(d, ins_shapes[0][1])]
def hstack(blocks, format=None, dtype=None):
def vstack(blocks, format=None, dtype=None):
"""Stack sparse matrices vertically (row wise).
This wrap the method vstack from scipy.
:param blocks: List of sparse array of compatible shape.
:param format: String representing the output format.
:param format: String representing the output format. Defaul
is csc.
:param dtype: Output dtype.
:return: The concatenation of the sparse array row wise.
......@@ -226,6 +231,9 @@ def hstack(blocks, format=None, dtype=None):
- The number of column of the sparse matrix must agree.
- The grad implemented is regular, i.e. not structured.
"""
if dtype is None:
raise ValueError('The output dtype must be specified.')
return VStack(format=format, dtype=dtype)(*blocks)
......@@ -270,11 +278,8 @@ class AddSSData(gof.op.Op):
def grad(self, inputs, (gz, )):
is_continuous = [(i.dtype in sparse.continuous_dtypes)
for i in inputs]
if all(is_continuous):
return [gz, gz]
else:
return [None] * len(inputs)
derivative = {True: gz, False: None}
return [derivative[b] for b in is_continuous]
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
......
......@@ -163,7 +163,8 @@ class _HVStackTester(utt.InferShapeTester):
def test_infer_shape(self):
for format in sparse.sparse_formats:
self._compile_and_check(self.x[format],
[self.op_class()(*self.x[format])],
[self.op_class(theano.config.floatX)
(*self.x[format])],
self.mat[format],
self.op_class)
......@@ -171,15 +172,11 @@ class _HVStackTester(utt.InferShapeTester):
for format in sparse.sparse_formats:
for out_f in sparse.sparse_formats:
for dtype in sparse.float_dtypes:
eps = None
if dtype == 'float32':
eps = 7e-4
verify_grad_sparse(
self.op_class(format=out_f, dtype=dtype),
self.mat[format],
structured=False,
eps=eps)
eps=7e-4)
def _hv_switch(op, expected_function):
......@@ -195,10 +192,6 @@ def _hv_switch(op, expected_function):
def expected_f(self, a, format=None, dtype=None):
return expected_function(a, format, dtype)
def setUp(self):
super(XStackTester, self).setUp()
return XStackTester
HStackTester = _hv_switch(S2.HStack, sp.hstack)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论