提交 f05af839 authored 作者: Frederic's avatar Frederic

Split (join grad), now support split of 0 element.

上级 95b4fec5
......@@ -3150,8 +3150,9 @@ class Split(Op):
if numpy.sum(splits) != len_along_axis:
raise ValueError('The splits sum to %s, expected %s' %
(numpy.sum(splits), len_along_axis))
if not python_all(splits):
raise ValueError('Cannot have a split of zero.')
if python_any([nb < 0 for nb in splits]):
raise ValueError('Split: you try to make an ndarray with'
'negative number of elements.')
# Checking is done, let's roll the splitting algorithm!
# Basically we step along the given axis of x, extracting
......
......@@ -3547,6 +3547,22 @@ class T_Join_and_Split(unittest.TestCase):
m = self.shared(rng.rand(4, 4).astype(self.floatX))
self.assertRaises(TypeError, self.join_op(), 0, v, m)
def test_split_0elem(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [4, 0])
f = function([], o, mode=self.mode)
o1, o2 = f()
assert numpy.allclose(o1, m.get_value(borrow=True))
assert numpy.allclose(o2, m.get_value(borrow=True)[4:])
def test_split_neg(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [5, -1])
f = function([], o, mode=self.mode)
self.assertRaises(ValueError, f)
class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论