提交 63b5ea7f authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

Made corrections from review.

上级 28bdcbe3
...@@ -76,8 +76,9 @@ class HStack(gof.op.Op): ...@@ -76,8 +76,9 @@ class HStack(gof.op.Op):
"""Stack sparse matrices horizontally (column wise). """Stack sparse matrices horizontally (column wise).
:param blocks: Sequence of sparse array of compatible shape. :param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format. :param format: String representing the output format. Defaul
:param dtype: Output dtype. is csc.
:param dtype: Output dtype. Must be specified.
:return: The concatenation of the sparse arrays column wise. :return: The concatenation of the sparse arrays column wise.
...@@ -86,15 +87,15 @@ class HStack(gof.op.Op): ...@@ -86,15 +87,15 @@ class HStack(gof.op.Op):
- The grad implemented is regular, i.e. not structured. - The grad implemented is regular, i.e. not structured.
""" """
def __init__(self, format=None, dtype=None): def __init__(self, dtype, format=None):
if format is None: if format is None:
self.format = 'csc' self.format = 'csc'
else: else:
self.format = format self.format = format
if dtype is None: if dtype is None:
self.dtype = theano.config.floatX raise ValueError('The output dtype must be specified.')
else: self.dtype = dtype
self.dtype = dtype
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
...@@ -152,26 +153,29 @@ def hstack(blocks, format=None, dtype=None): ...@@ -152,26 +153,29 @@ def hstack(blocks, format=None, dtype=None):
This wrap the method hstack from scipy. This wrap the method hstack from scipy.
:param blocks: List of sparse array of compatible shape. :param blocks: List of sparse array of compatible shape.
:param format: String representing the output format. :param format: String representing the output format. Defaul
:param dtype: Output dtype. is csc.
:param dtype: Output dtype. Must be specified.
:return: The concatenation of the sparse array column wise. :return: The concatenation of the sparse array column wise.
:note: :note:
- The number of line of the sparse matrix must agree. - The number of line of the sparse matrix must agree.
- The grad implemented is regular, i.e. not structured. - The grad implemented is regular, i.e. not structured.
""" """
return HStack(format=format, dtype=dtype)(*blocks) if dtype is None:
raise ValueError('The output dtype must be specified.')
return HStack(dtype, format=format)(*blocks)
class VStack(HStack): class VStack(HStack):
"""Stack sparse matrices vertically (row wise). """Stack sparse matrices vertically (row wise).
:param blocks: Sequence of sparse array of compatible shape. :param blocks: Sequence of sparse array of compatible shape.
:param format: String representing the output format. :param format: String representing the output format. Defaul
:param dtype: Output dtype. is csc.
:param dtype: Output dtype. Must be specified.
:return: The concatenation of the sparse arrays row wise. :return: The concatenation of the sparse arrays row wise.
...@@ -211,13 +215,14 @@ class VStack(HStack): ...@@ -211,13 +215,14 @@ class VStack(HStack):
return [(d, ins_shapes[0][1])] return [(d, ins_shapes[0][1])]
def hstack(blocks, format=None, dtype=None): def vstack(blocks, format=None, dtype=None):
"""Stack sparse matrices vertically (row wise). """Stack sparse matrices vertically (row wise).
This wrap the method vstack from scipy. This wrap the method vstack from scipy.
:param blocks: List of sparse array of compatible shape. :param blocks: List of sparse array of compatible shape.
:param format: String representing the output format. :param format: String representing the output format. Defaul
is csc.
:param dtype: Output dtype. :param dtype: Output dtype.
:return: The concatenation of the sparse array row wise. :return: The concatenation of the sparse array row wise.
...@@ -226,6 +231,9 @@ def hstack(blocks, format=None, dtype=None): ...@@ -226,6 +231,9 @@ def hstack(blocks, format=None, dtype=None):
- The number of column of the sparse matrix must agree. - The number of column of the sparse matrix must agree.
- The grad implemented is regular, i.e. not structured. - The grad implemented is regular, i.e. not structured.
""" """
if dtype is None:
raise ValueError('The output dtype must be specified.')
return VStack(format=format, dtype=dtype)(*blocks) return VStack(format=format, dtype=dtype)(*blocks)
...@@ -270,11 +278,8 @@ class AddSSData(gof.op.Op): ...@@ -270,11 +278,8 @@ class AddSSData(gof.op.Op):
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
is_continuous = [(i.dtype in sparse.continuous_dtypes) is_continuous = [(i.dtype in sparse.continuous_dtypes)
for i in inputs] for i in inputs]
derivative = {True: gz, False: None}
if all(is_continuous): return [derivative[b] for b in is_continuous]
return [gz, gz]
else:
return [None] * len(inputs)
def infer_shape(self, node, ins_shapes): def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]] return [ins_shapes[0]]
......
...@@ -163,7 +163,8 @@ class _HVStackTester(utt.InferShapeTester): ...@@ -163,7 +163,8 @@ class _HVStackTester(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
self._compile_and_check(self.x[format], self._compile_and_check(self.x[format],
[self.op_class()(*self.x[format])], [self.op_class(theano.config.floatX)
(*self.x[format])],
self.mat[format], self.mat[format],
self.op_class) self.op_class)
...@@ -171,15 +172,11 @@ class _HVStackTester(utt.InferShapeTester): ...@@ -171,15 +172,11 @@ class _HVStackTester(utt.InferShapeTester):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
for out_f in sparse.sparse_formats: for out_f in sparse.sparse_formats:
for dtype in sparse.float_dtypes: for dtype in sparse.float_dtypes:
eps = None
if dtype == 'float32':
eps = 7e-4
verify_grad_sparse( verify_grad_sparse(
self.op_class(format=out_f, dtype=dtype), self.op_class(format=out_f, dtype=dtype),
self.mat[format], self.mat[format],
structured=False, structured=False,
eps=eps) eps=7e-4)
def _hv_switch(op, expected_function): def _hv_switch(op, expected_function):
...@@ -195,10 +192,6 @@ def _hv_switch(op, expected_function): ...@@ -195,10 +192,6 @@ def _hv_switch(op, expected_function):
def expected_f(self, a, format=None, dtype=None): def expected_f(self, a, format=None, dtype=None):
return expected_function(a, format, dtype) return expected_function(a, format, dtype)
def setUp(self):
super(XStackTester, self).setUp()
return XStackTester return XStackTester
HStackTester = _hv_switch(S2.HStack, sp.hstack) HStackTester = _hv_switch(S2.HStack, sp.hstack)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论