提交 54bf7b7e，作者：Frederic

Fixes to choose.

上级 d9735d34
......@@ -5200,7 +5200,15 @@ class Choose(Op):
if isinstance(node.inputs[1], TensorVariable):
# We have padded node.inputs[0] to the right number of
# dimensions for the output
return[(shapes[0])]
l = []
for sh1, sh2, b1 in zip(shapes[0],
shapes[1][1:],
node.inputs[0].broadcastable):
if b1:
l.append(sh2)
else:
l.append(sh1)
return [tuple(l)]
else:
import theano.typed_list
assert isinstance(node.inputs[1],
......@@ -5221,7 +5229,8 @@ class Choose(Op):
'choose first argument must have an [u]int* dtype. Got %s.'
% a.dtype)
if isinstance(choices, (tuple, list)):
if isinstance(choices, (tuple, list,
theano.typed_list.TypedListVariable)):
choice = theano.typed_list.make_list(choices)
choice_ndim = choice.ttype.ndim
choice_bcast = choice.ttype.broadcastable
......@@ -5230,7 +5239,24 @@ class Choose(Op):
choice_ndim = choice.ndim - 1
choice_bcast = choice.broadcastable[1:]
out_ndim = numpy.max([a.ndim, choice_ndim])
# Make explicit all added broadcastable dimensions.
a = shape_padleft(a, out_ndim - a.ndim)
if len(choice_bcast) != out_ndim:
if isinstance(choice.type, TensorType):
choice = choice.dimshuffle(0,
*(('x',) *(out_ndim - choice_ndim) +
tuple(range(1, choice.ndim))))
choice_ndim = choice.ndim - 1
choice_bcast = choice.broadcastable[1:]
else:
raise NotImplementedError(
"We currently didn't implemented that case. "
"To make it work, explicitly add dimensions "
"of size one for dimensions that will be broadcasted")
assert isinstance(node.inputs[1],
theano.typed_list.TypedListVariable)
bcast = [False] * out_ndim
for idx, (b1, b2) in enumerate(
zip(a.broadcastable,
......
......@@ -7114,26 +7114,32 @@ class T_Choose(utt.InferShapeTester):
def test_infer_shape(self):
for shp1, shp2 in [
((5, 4), (7, 4)),
((4,), (4,)),
((1, 4), (7, 4)),
((5, 1), (7, 4)),
((5, 4), (1, 4)),
((5, 4), (7, 1)),
((5, 4), (4,)),
((1, 4), (4,)),
((5, 1), (4,)),
((5, 4), (1,)),
((4,), (5, 4)),
((1,), (5, 4)),
((4,), (1, 4)),
((4,), (3, 1)),
((1, 4), (7, 4)),
((4,), (4,)),
((1,), (4,)),
((1, 4), (4,)),
# The following case cause error from NumPy.
# ((5, 4), (1, 4)),
# ((1,), (1,)),
# ((4,), (1,)),
# ((4,), (1, 4)),
# ((4,), (3, 1)),
((4,), (1,)),
((1,), (1,)),
]:
a = tensor.tensor(dtype='int32',
broadcastable=[n == 1 for n in shp1])
c = tensor.tensor(dtype='float32',
broadcastable=[n == 1 for n in shp2])
A = numpy.asarray(numpy.random.rand(*shp1) * 4, dtype='int32')
C = numpy.asarray(numpy.random.rand(*shp2) * 4, dtype='float32')
A = numpy.asarray(numpy.random.rand(*shp1) * shp2[0], dtype='int32')
C = numpy.asarray(numpy.random.rand(*shp2) * shp2[0], dtype='float32')
self._compile_and_check([a, c], # theano.function inputs
[self.op(a, c)], # theano.function outputs
# Always use not square matrix!
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论