提交 56b2db6b authored 作者: James Bergstra's avatar James Bergstra

corrected typos in horizontal_stack and vertical_stack, added test cases

上级 92130054
...@@ -967,6 +967,25 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -967,6 +967,25 @@ class T_Join_and_Split(unittest.TestCase):
verify_grad(self, lambda a, b: join(1,a,b), [av, bv], eps=1.0e-4, tol=1.0e-3) verify_grad(self, lambda a, b: join(1,a,b), [av, bv], eps=1.0e-4, tol=1.0e-3)
def test_join_matrix1_using_vertical_stack(self):
a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]]))
b = as_tensor(numpy.array([[7, 8, 9]]))
s = vertical_stack(a, b)
want = numpy.array([[1, 2, 3],[4,5,6],[7, 8, 9]])
self.failUnless((eval_outputs([s]) == want).all())
def test_join_matrix1_using_horizontal_stack(self):
av=numpy.array([[1, 2, 3], [4, 5, 6]], dtype='float32')
bv= numpy.array([[7], [8]],dtype='float32')
a = as_tensor(av)
b = as_tensor(bv)
s = horizontal_stack(a, b)
want = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
self.failUnless((eval_outputs([s]) == want).all())
verify_grad(self, lambda a, b: join(1,a,b), [av, bv], eps=1.0e-4, tol=1.0e-3)
def test_join_matrixV(self): def test_join_matrixV(self):
"""variable join axis""" """variable join axis"""
v = numpy.array([[1., 2., 3.], [4., 5., 6.]]) v = numpy.array([[1., 2., 3.], [4., 5., 6.]])
......
...@@ -1651,14 +1651,14 @@ def horizontal_stack(*args): ...@@ -1651,14 +1651,14 @@ def horizontal_stack(*args):
@note: Unlike VerticalStack, we assume that the L{Tensor}s have @note: Unlike VerticalStack, we assume that the L{Tensor}s have
two dimensions. two dimensions.
""" """
assert x.type.ndim == 2 for arg in args:
assert y.type.ndim == 2 assert arg.type.ndim == 2
return concatenate([x,y], axis=1) return concatenate(args, axis=1)
@constructor @constructor
def vertical_stack(*args): def vertical_stack(*args):
assert x.type.ndim == 2 for arg in args:
assert y.type.ndim == 2 assert arg.type.ndim == 2
return concatenate(args, axis=0) return concatenate(args, axis=0)
if 0: #vertical and horizontal stacking are deprecated. Better to use stack() and join(). if 0: #vertical and horizontal stacking are deprecated. Better to use stack() and join().
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论