提交 58ff4eb7 authored 作者: LegrandNico's avatar LegrandNico

- Only use `make_list` if choices is a list or tuple and has inconsistent shapes…

- Only use `make_list` if choices is a list or tuple and has inconsistent shapes; otherwise use `as_tensor_variable`. - Add a test function using tuples and lists of scalars and 3d tensors. - Remove support for TypedListVariable.
上级 7be0eaea
...@@ -7136,21 +7136,36 @@ class TestChoose(utt.InferShapeTester): ...@@ -7136,21 +7136,36 @@ class TestChoose(utt.InferShapeTester):
with pytest.raises(TypeError): with pytest.raises(TypeError):
choose(a, b) choose(a, b)
def test_numpy_compare_tuple(self): @pytest.mark.parametrize(
"test_input",
a = tt.tensor3(dtype="int32") [
b = tt.tensor3(dtype="float32") (
c = tt.tensor3(dtype="float32") tt.tensor3(dtype="int32"),
tt.tensor3(dtype="float32"),
A = np.random.randint(0, 2, (2, 1, 1)).astype("int32") tt.tensor3(dtype="float32"),
B = np.asarray(np.random.rand(1, 6, 1), dtype="float32") np.random.randint(0, 2, (2, 1, 1)).astype("int32"),
C = np.asarray(np.random.rand(1, 1, 5), dtype="float32") np.asarray(np.random.rand(1, 6, 1), dtype="float32"),
np.asarray(np.random.rand(1, 1, 5), dtype="float32"),
),
(
tt.vector(dtype="int32"),
tt.scalar(),
tt.scalar(),
[0, 1, 1, 0],
0.1,
0.2,
),
],
)
def test_numpy_compare_tuple(self, test_input):
"""Test with list and tuples of scalars and 3d tensors."""
a, b, c, A, B, C = test_input
for m in self.modes: for m in self.modes:
f = function([a, b, c], choose(a, (b, c), mode=m)) for ls in [list, tuple]:
t_c = f(A, B, C) f = function([a, b, c], choose(a, ls([b, c]), mode=m))
n_c = np.choose(A, (B, C), mode=m) t_c = f(A, B, C)
assert np.allclose(t_c, n_c) n_c = np.choose(A, ls([B, C]), mode=m)
assert np.allclose(t_c, n_c)
def test_infer_shape(self): def test_infer_shape(self):
for shp1, shp2 in [ for shp1, shp2 in [
......
...@@ -7126,27 +7126,12 @@ class Choose(Op): ...@@ -7126,27 +7126,12 @@ class Choose(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
if isinstance(node.inputs[1], TensorVariable): a_shape, choices_shape = shapes
# We have padded node.inputs[0] to the right number of out_shape = theano.tensor.extra_ops.broadcast_shape(
# dimensions for the output a_shape, choices_shape[1:], arrays_are_shapes=True
l = [] )
for sh1, sh2, b1 in zip(
shapes[0], shapes[1][1:], node.inputs[0].broadcastable
):
if b1:
l.append(sh2)
else:
l.append(sh1)
return [tuple(l)]
else:
import theano.typed_list
assert isinstance(node.inputs[1], theano.typed_list.TypedListVariable) return [out_shape]
raise ShapeError("Case not implemented")
shape = shapes[0]
for i in range(len(shapes[0]) - 1):
shape[i] = shapes[1][i]
return [(shape)]
def make_node(self, a, choices): def make_node(self, a, choices):
# Import here as it isn't imported by default and we can't # Import here as it isn't imported by default and we can't
...@@ -7159,39 +7144,28 @@ class Choose(Op): ...@@ -7159,39 +7144,28 @@ class Choose(Op):
"choose first argument must have an [u]int* dtype. Got %s." % a.dtype "choose first argument must have an [u]int* dtype. Got %s." % a.dtype
) )
if isinstance(choices, (tuple, list, theano.typed_list.TypedListVariable)): # Only use make_list if choices have inconsistent shapes
# otherwise use as_tensor_variable
if isinstance(choices, (tuple, list)):
choice = theano.typed_list.make_list(choices) choice = theano.typed_list.make_list(choices)
choice_ndim = choice.ttype.ndim
choice_bcast = choice.ttype.broadcastable
else: else:
choice = as_tensor_variable(choices) choice = as_tensor_variable(choices)
choice_ndim = choice.ndim - 1 (out_shape,) = self.infer_shape(
choice_bcast = choice.broadcastable[1:] None, [tuple(a.shape), tuple(theano.tensor.basic.shape(choice))]
out_ndim = np.max([a.ndim, choice_ndim]) )
# Make explicit all added broadcastable dimensions. bcast = []
a = shape_padleft(a, out_ndim - a.ndim) for s in out_shape:
if len(choice_bcast) != out_ndim: try:
if isinstance(choice.type, TensorType): s_val = theano.get_scalar_constant_value(s)
choice = choice.dimshuffle( except (theano.tensor.basic.NotScalarConstantError, AttributeError):
0, s_val = None
*(("x",) * (out_ndim - choice_ndim) + tuple(range(1, choice.ndim))),
) if s_val == 1:
choice_ndim = choice.ndim - 1 bcast.append(True)
choice_bcast = choice.broadcastable[1:]
else: else:
raise NotImplementedError( bcast.append(False)
"We currently didn't implemented that case. "
"To make it work, explicitly add dimensions "
"of size one for dimensions that will be broadcasted"
)
bcast = [False] * out_ndim
for idx, (b1, b2) in enumerate(
zip(a.broadcastable, (True,) * (out_ndim - choice_ndim) + choice_bcast)
):
if b1 and b2:
bcast[idx] = True
o = TensorType(choice.dtype, bcast) o = TensorType(choice.dtype, bcast)
return Apply(self, [a, choice], [o()]) return Apply(self, [a, choice], [o()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论