Commit d9735d34 authored by Frederic

Better choose.infer_shape error.

Parent f31a4118
...@@ -5198,6 +5198,8 @@ class Choose(Op): ...@@ -5198,6 +5198,8 @@ class Choose(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
if isinstance(node.inputs[1], TensorVariable): if isinstance(node.inputs[1], TensorVariable):
# We have padded node.inputs[0] to the right number of
# dimensions for the output
return[(shapes[0])] return[(shapes[0])]
else: else:
import theano.typed_list import theano.typed_list
...@@ -5214,13 +5216,28 @@ class Choose(Op): ...@@ -5214,13 +5216,28 @@ class Choose(Op):
# import at the top as it would cause circular import. # import at the top as it would cause circular import.
import theano.typed_list import theano.typed_list
a = as_tensor_variable(a) a = as_tensor_variable(a)
if "int" not in a.dtype:
raise TypeError(
'choose first argument must have an [u]int* dtype. Got %s.'
% a.dtype)
if isinstance(choices, (tuple, list)): if isinstance(choices, (tuple, list)):
choice = theano.typed_list.make_list(choices) choice = theano.typed_list.make_list(choices)
dtype = choice.ttype.dtype choice_ndim = choice.ttype.ndim
choice_bcast = choice.ttype.broadcastable
else: else:
choice = as_tensor_variable(choices) choice = as_tensor_variable(choices)
o = TensorType(choice.dtype, a.broadcastable) choice_ndim = choice.ndim - 1
choice_bcast = choice.broadcastable[1:]
out_ndim = numpy.max([a.ndim, choice_ndim])
a = shape_padleft(a, out_ndim - a.ndim)
bcast = [False] * out_ndim
for idx, (b1, b2) in enumerate(
zip(a.broadcastable,
(True,) * (out_ndim - choice_ndim) + choice_bcast)):
if b1 and b2:
bcast[idx] = True
o = TensorType(choice.dtype, bcast)
return Apply(self, [a, choice], [o()]) return Apply(self, [a, choice], [o()])
def perform(self, node, inputs, (z, )): def perform(self, node, inputs, (z, )):
......
...@@ -7045,6 +7045,7 @@ class T_Power(unittest.TestCase): ...@@ -7045,6 +7045,7 @@ class T_Power(unittest.TestCase):
class T_Choose(utt.InferShapeTester): class T_Choose(utt.InferShapeTester):
op = staticmethod(choose) op = staticmethod(choose)
op_class = Choose op_class = Choose
modes = ['raise', 'wrap', 'clip']
def test_numpy_compare(self): def test_numpy_compare(self):
...@@ -7055,14 +7056,44 @@ class T_Choose(utt.InferShapeTester): ...@@ -7055,14 +7056,44 @@ class T_Choose(utt.InferShapeTester):
dtype='int32') dtype='int32')
B = numpy.asarray(numpy.random.rand(4, 4), dtype='float32') B = numpy.asarray(numpy.random.rand(4, 4), dtype='float32')
modes = ['raise', 'wrap', 'clip'] for m in self.modes:
f = function([a, b], choose(a, b, mode=m))
t_c = f(A, B)
n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c)
def test_broadcasted(self):
a = tensor.scalar(dtype='int32')
b = tensor.matrix(dtype='float32')
for m in modes: # Test when a is broadcastable
A = 3
B = numpy.asarray(numpy.random.rand(4, 4), dtype='float32')
for m in self.modes:
f = function([a, b], choose(a, b, mode=m)) f = function([a, b], choose(a, b, mode=m))
t_c = f(A, B) t_c = f(A, B)
n_c = numpy.choose(A, B, mode=m) n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c) assert numpy.allclose(t_c, n_c)
# Test when the result should be broadcastable
b = theano.tensor.col(dtype='float32')
B = numpy.asarray(numpy.random.rand(4, 1), dtype='float32')
for m in self.modes:
f = function([a, b], choose(a, b, mode=m))
assert choose(a, b, mode=m).broadcastable[0]
t_c = f(A, B)
n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c)
def test_dtype_error(self):
a = tensor.scalar(dtype='float32')
b = tensor.matrix(dtype='float32')
A = 3
B = numpy.asarray(numpy.random.rand(4, 4), dtype='float32')
self.assertRaises(TypeError, choose, a, b)
def test_numpy_compare_tuple(self): def test_numpy_compare_tuple(self):
a = tensor.tensor3(dtype='int32') a = tensor.tensor3(dtype='int32')
...@@ -7074,36 +7105,42 @@ class T_Choose(utt.InferShapeTester): ...@@ -7074,36 +7105,42 @@ class T_Choose(utt.InferShapeTester):
B = numpy.asarray(numpy.random.rand(1, 6, 1), dtype='float32') B = numpy.asarray(numpy.random.rand(1, 6, 1), dtype='float32')
C = numpy.asarray(numpy.random.rand(1, 1, 5), dtype='float32') C = numpy.asarray(numpy.random.rand(1, 1, 5), dtype='float32')
f = function([a, b, c], choose(a, (b, c))) for m in self.modes:
t_c = f(A, B, C) f = function([a, b, c], choose(a, (b, c), mode=m))
n_c = numpy.choose(A, (B, C)) t_c = f(A, B, C)
assert numpy.allclose(t_c, n_c) n_c = numpy.choose(A, (B, C), mode=m)
assert numpy.allclose(t_c, n_c)
def test_infer_shape(self): def test_infer_shape(self):
for shp1, shp2 in [
a = tensor.matrix(dtype='int32') ((5, 4), (7, 4)),
b = tensor.vector(dtype='int32') ((4,), (4,)),
c = tensor.matrix(dtype='int32') ((5, 4), (4,)),
d = tensor.vector(dtype='int32') ((4,), (5, 4)),
A = numpy.asarray(numpy.random.rand(5, 4) * 4, dtype='int32') ((1, 4), (7, 4)),
B = numpy.asarray(numpy.random.rand(4) * 4, dtype='int32') ((1,), (4,)),
C = numpy.asarray(numpy.random.rand(7, 4) * 4, dtype='int32') ((1, 4), (4,)),
D = numpy.asarray(numpy.random.rand(4) * 4, dtype='int32') # The following case cause error from NumPy.
# ((5, 4), (1, 4)),
var1 = [a, b, a, b] # ((1,), (1,)),
var2 = [c, d, b, a] # ((4,), (1,)),
mat1 = [A, B, A, B] # ((4,), (1, 4)),
mat2 = [C, D, B, A] # ((4,), (3, 1)),
]:
for v, m, w, n in zip(var1, mat1, var2, mat2): a = tensor.tensor(dtype='int32',
self._compile_and_check([v, w], # theano.function inputs broadcastable=[n == 1 for n in shp1])
[self.op(v, w)], # theano.function outputs c = tensor.tensor(dtype='float32',
# Always use not square matrix! broadcastable=[n == 1 for n in shp2])
# inputs data A = numpy.asarray(numpy.random.rand(*shp1) * 4, dtype='int32')
[m, n], C = numpy.asarray(numpy.random.rand(*shp2) * 4, dtype='float32')
# Op that should be removed from the graph. self._compile_and_check([a, c], # theano.function inputs
self.op_class) [self.op(a, c)], # theano.function outputs
# Always use not square matrix!
# inputs data
[A, C],
# Op that should be removed from the graph.
self.op_class)
# Disabled as it isn't implemented. # Disabled as it isn't implemented.
def ___test_infer_shape_tuple(self): def ___test_infer_shape_tuple(self):
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment