提交 f6094481 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed architecture-dependent bug, due to Numpy flaw

上级 a463e474
......@@ -1300,7 +1300,9 @@ class MaxAndArgmax(Op):
return Apply(self, inputs, outputs)
def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.asarray(numpy.max(x, axis))
max_idx[0] = numpy.asarray(numpy.argmax(x, axis), dtype='int32')
# Note: using 'view' is important until Numpy's ticket 870 is resolved.
max_idx[0] = numpy.asarray(numpy.argmax(x, axis), dtype='int32').view(
numpy.int32)
def grad(self, (x, axis), (g_max, g_max_idx)):
# @warning: This only works if axis is 0, else the max is
# broadcasted wrong in the call to eq.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论