提交 a4a40a60 authored 作者: abergeron's avatar abergeron

Merge pull request #1910 from nouiz/split_0idx

Allow split with 0 index
...@@ -770,7 +770,7 @@ class test_comparison(unittest.TestCase): ...@@ -770,7 +770,7 @@ class test_comparison(unittest.TestCase):
y = theano.tensor.matrix() y = theano.tensor.matrix()
m1 = sp.csc_matrix((2, 2), dtype=theano.config.floatX) m1 = sp.csc_matrix((2, 2), dtype=theano.config.floatX)
m2 = numpy.asarray([[0, 0], [0, 0]]) m2 = numpy.asarray([[0, 0], [0, 0]], dtype=theano.config.floatX)
for func in self.testsDic: for func in self.testsDic:
......
...@@ -3150,8 +3150,9 @@ class Split(Op): ...@@ -3150,8 +3150,9 @@ class Split(Op):
if numpy.sum(splits) != len_along_axis: if numpy.sum(splits) != len_along_axis:
raise ValueError('The splits sum to %s, expected %s' % raise ValueError('The splits sum to %s, expected %s' %
(numpy.sum(splits), len_along_axis)) (numpy.sum(splits), len_along_axis))
if not python_all(splits): if python_any([nb < 0 for nb in splits]):
raise ValueError('Cannot have a split of zero.') raise ValueError('Split: you try to make an ndarray with'
'negative number of elements.')
# Checking is done, let's roll the splitting algorithm! # Checking is done, let's roll the splitting algorithm!
# Basically we step along the given axis of x, extracting # Basically we step along the given axis of x, extracting
......
...@@ -2343,6 +2343,7 @@ def test_batched_dot(): ...@@ -2343,6 +2343,7 @@ def test_batched_dot():
assert result.shape[0] == first_mat_val.shape[0] assert result.shape[0] == first_mat_val.shape[0]
def test_batched_tensordot(): def test_batched_tensordot():
first = theano.tensor.tensor4("first") first = theano.tensor.tensor4("first")
second = theano.tensor.tensor4("second") second = theano.tensor.tensor4("second")
...@@ -2364,10 +2365,10 @@ def test_batched_tensordot(): ...@@ -2364,10 +2365,10 @@ def test_batched_tensordot():
second_mat_val = numpy.random.rand(10, 4).astype(config.floatX) second_mat_val = numpy.random.rand(10, 4).astype(config.floatX)
result_fn = theano.function([first_mat, second_mat], output) result_fn = theano.function([first_mat, second_mat], output)
result = result_fn(first_mat_val, second_mat_val) result = result_fn(first_mat_val, second_mat_val)
print(result.shape)
assert result.shape[0] == first_mat_val.shape[0] assert result.shape[0] == first_mat_val.shape[0]
assert len(result.shape) == 1 assert len(result.shape) == 1
def test_tensor_values_eq_approx(): def test_tensor_values_eq_approx():
#test, inf, -inf and nan equal themself #test, inf, -inf and nan equal themself
a = numpy.asarray([-numpy.inf, -1, 0, 1, numpy.inf, numpy.nan]) a = numpy.asarray([-numpy.inf, -1, 0, 1, numpy.inf, numpy.nan])
...@@ -3145,15 +3146,12 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3145,15 +3146,12 @@ class T_Join_and_Split(unittest.TestCase):
b_v = numpy.random.rand(4) b_v = numpy.random.rand(4)
f = theano.function([a, b], [Ha, Hb]) f = theano.function([a, b], [Ha, Hb])
Ha_v, Hb_v = f(a_v, b_v) Ha_v, Hb_v = f(a_v, b_v)
print Ha_v
print Hb_v
# The Hessian is always a matrix full of 0 # The Hessian is always a matrix full of 0
assert Ha_v.shape == (4, 4) assert Ha_v.shape == (4, 4)
assert Hb_v.shape == (4, 4) assert Hb_v.shape == (4, 4)
assert numpy.allclose(Ha_v, 0.) assert numpy.allclose(Ha_v, 0.)
assert numpy.allclose(Hb_v, 0.) assert numpy.allclose(Hb_v, 0.)
def test_join_concatenate_one_element(self): def test_join_concatenate_one_element(self):
''' Fast test of concatenate as this is an alias for join. ''' Fast test of concatenate as this is an alias for join.
also test that we remove the Join op if there is only 1 input''' also test that we remove the Join op if there is only 1 input'''
...@@ -3547,6 +3545,26 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3547,6 +3545,26 @@ class T_Join_and_Split(unittest.TestCase):
m = self.shared(rng.rand(4, 4).astype(self.floatX)) m = self.shared(rng.rand(4, 4).astype(self.floatX))
self.assertRaises(TypeError, self.join_op(), 0, v, m) self.assertRaises(TypeError, self.join_op(), 0, v, m)
def test_split_0elem(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [4, 0])
f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op)
for node in f.maker.fgraph.toposort()])
o1, o2 = f()
assert numpy.allclose(o1, m.get_value(borrow=True))
assert numpy.allclose(o2, m.get_value(borrow=True)[4:])
def test_split_neg(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [5, -1])
f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op)
for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f)
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and != """Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论