提交 d5065649 authored 作者: Chinnadhurai Sankar's avatar Chinnadhurai Sankar 提交者: Pascal Lamblin

cast numpy.prod(arg) to int when arg is empty

上级 56a71e79
...@@ -1232,7 +1232,10 @@ class MaxAndArgmax(Op): ...@@ -1232,7 +1232,10 @@ class MaxAndArgmax(Op):
transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes))) transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes)))
kept_shape = transposed_x.shape[:len(keep_axes)] kept_shape = transposed_x.shape[:len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes):] reduced_shape = transposed_x.shape[len(keep_axes):]
new_shape = kept_shape + (numpy.prod(reduced_shape),)
# Numpy.prod returns 1.0 when arg is empty, so we cast it to int
# Otherwise reshape would complain citiing float arg
new_shape = kept_shape + (int(numpy.prod(reduced_shape)),)
reshaped_x = transposed_x.reshape(new_shape) reshaped_x = transposed_x.reshape(new_shape)
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论