提交 a4a40a60 authored 作者: abergeron's avatar abergeron

Merge pull request #1910 from nouiz/split_0idx

Allow split with 0 index
......@@ -770,7 +770,7 @@ class test_comparison(unittest.TestCase):
y = theano.tensor.matrix()
m1 = sp.csc_matrix((2, 2), dtype=theano.config.floatX)
m2 = numpy.asarray([[0, 0], [0, 0]])
m2 = numpy.asarray([[0, 0], [0, 0]], dtype=theano.config.floatX)
for func in self.testsDic:
......
......@@ -3150,8 +3150,9 @@ class Split(Op):
if numpy.sum(splits) != len_along_axis:
raise ValueError('The splits sum to %s, expected %s' %
(numpy.sum(splits), len_along_axis))
if not python_all(splits):
raise ValueError('Cannot have a split of zero.')
if python_any([nb < 0 for nb in splits]):
raise ValueError('Split: you try to make an ndarray with'
'negative number of elements.')
# Checking is done, let's roll the splitting algorithm!
# Basically we step along the given axis of x, extracting
......
......@@ -2343,6 +2343,7 @@ def test_batched_dot():
assert result.shape[0] == first_mat_val.shape[0]
def test_batched_tensordot():
first = theano.tensor.tensor4("first")
second = theano.tensor.tensor4("second")
......@@ -2364,10 +2365,10 @@ def test_batched_tensordot():
second_mat_val = numpy.random.rand(10, 4).astype(config.floatX)
result_fn = theano.function([first_mat, second_mat], output)
result = result_fn(first_mat_val, second_mat_val)
print(result.shape)
assert result.shape[0] == first_mat_val.shape[0]
assert len(result.shape) == 1
def test_tensor_values_eq_approx():
#test, inf, -inf and nan equal themself
a = numpy.asarray([-numpy.inf, -1, 0, 1, numpy.inf, numpy.nan])
......@@ -3145,15 +3146,12 @@ class T_Join_and_Split(unittest.TestCase):
b_v = numpy.random.rand(4)
f = theano.function([a, b], [Ha, Hb])
Ha_v, Hb_v = f(a_v, b_v)
print Ha_v
print Hb_v
# The Hessian is always a matrix full of 0
assert Ha_v.shape == (4, 4)
assert Hb_v.shape == (4, 4)
assert numpy.allclose(Ha_v, 0.)
assert numpy.allclose(Hb_v, 0.)
def test_join_concatenate_one_element(self):
''' Fast test of concatenate as this is an alias for join.
also test that we remove the Join op if there is only 1 input'''
......@@ -3547,6 +3545,26 @@ class T_Join_and_Split(unittest.TestCase):
m = self.shared(rng.rand(4, 4).astype(self.floatX))
self.assertRaises(TypeError, self.join_op(), 0, v, m)
def test_split_0elem(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [4, 0])
f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op)
for node in f.maker.fgraph.toposort()])
o1, o2 = f()
assert numpy.allclose(o1, m.get_value(borrow=True))
assert numpy.allclose(o2, m.get_value(borrow=True)[4:])
def test_split_neg(self):
rng = numpy.random.RandomState(seed=utt.fetch_seed())
m = self.shared(rng.rand(4, 6).astype(self.floatX))
o = self.split_op(2)(m, 0, [5, -1])
f = function([], o, mode=self.mode)
assert any([isinstance(node.op, self.split_op)
for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f)
class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论