提交 45316d0c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix tests for Repeat on 32-bit-int python

上级 8455b635
......@@ -303,6 +303,25 @@ class RepeatOp(theano.Op):
def make_node(self, x, repeats):
x = basic.as_tensor_variable(x)
repeats = basic.as_tensor_variable(repeats)
if repeats.dtype not in tensor.discrete_dtypes:
raise TypeError("repeats.dtype must be an integer.")
# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
int_bitwidth = theano.gof.cmodule.python_int_bitwidth()
if int_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32:
numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
if repeats.dtype in numpy_unsupported_dtypes:
raise TypeError(
("dtypes %s are not supported by numpy.repeat "
"for the 'repeats' parameter, "
% numpy_unsupported_dtypes), repeats.dtype)
if self.axis is None:
out_type = theano.tensor.TensorType(dtype=x.dtype,
broadcastable=[False])
......
......@@ -49,8 +49,7 @@ class TestBinCountOp(utt.InferShapeTester):
assert (np.bincount(a, minlength=5) == f4(a)).all()
def test_infer_shape(self):
for dtype in ('int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64'):
for dtype in tensor.discrete_dtypes:
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
int_bitwidth = theano.gof.cmodule.python_int_bitwidth()
......@@ -184,6 +183,13 @@ class TestRepeatOp(utt.InferShapeTester):
super(TestRepeatOp, self).setUp()
self.op_class = RepeatOp
self.op = RepeatOp()
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
int_bitwidth = theano.gof.cmodule.python_int_bitwidth()
if int_bitwidth == 64:
self.numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32:
self.numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
def test_repeatOp(self):
for ndim in range(3):
......@@ -191,19 +197,30 @@ class TestRepeatOp(utt.InferShapeTester):
a = np.random.random((10, ) * ndim).astype(config.floatX)
for axis in self._possible_axis(ndim):
r_var = T.lscalar()
r = 3
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
r_var = T.lvector()
if axis is None:
r = np.random.random_integers(5, size=a.size)
else:
r = np.random.random_integers(5, size=(10,))
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
for dtype in tensor.discrete_dtypes:
r_var = T.scalar(dtype=dtype)
r = numpy.asarray(3, dtype=dtype)
if dtype in self.numpy_unsupported_dtypes:
self.assertRaises(TypeError,
repeat, x, r_var, axis=axis)
else:
f = theano.function([x, r_var],
repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis),
f(a, r))
r_var = T.vector(dtype=dtype)
if axis is None:
r = np.random.random_integers(
5, size=a.size).astype(dtype)
else:
r = np.random.random_integers(
5, size=(10,)).astype(dtype)
f = theano.function([x, r_var],
repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis),
f(a, r))
def test_infer_shape(self):
for ndim in range(4):
......@@ -211,23 +228,31 @@ class TestRepeatOp(utt.InferShapeTester):
a = np.random.random((10, ) * ndim).astype(config.floatX)
for axis in self._possible_axis(ndim):
r_var = T.lscalar()
r = 3
self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
r_var = T.lvector()
if axis is None:
r = np.random.random_integers(5, size=a.size)
else:
r = np.random.random_integers(5, size=(10,))
for dtype in tensor.discrete_dtypes:
r_var = T.scalar(dtype=dtype)
r = numpy.asarray(3, dtype=dtype)
if dtype in self.numpy_unsupported_dtypes:
self.assertRaises(TypeError, repeat, x, r_var)
else:
self._compile_and_check(
[x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
r_var = T.vector(dtype=dtype)
if axis is None:
r = np.random.random_integers(
5, size=a.size).astype(dtype)
else:
r = np.random.random_integers(
5, size=(10,)).astype(dtype)
self._compile_and_check(
[x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
def test_grad(self):
for ndim in range(3):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论