提交 dd41a16d authored 作者: James Bergstra's avatar James Bergstra

argmax output dtype is now int32

上级 4dfd1f0a
...@@ -970,11 +970,11 @@ class MaxAndArgmax(Op): ...@@ -970,11 +970,11 @@ class MaxAndArgmax(Op):
inputs = [x, axis] inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative
outputs = [tensor(x.type.dtype, broadcastable), outputs = [tensor(x.type.dtype, broadcastable),
tensor(axis.type.dtype, broadcastable)] tensor('int32', broadcastable)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, (x, axis), (max, max_idx)): def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.asarray(numpy.max(x, axis)) max[0] = numpy.asarray(numpy.max(x, axis))
max_idx[0] = numpy.asarray(numpy.argmax(x, axis)) max_idx[0] = numpy.asarray(numpy.argmax(x, axis), dtype='int32')
def grad(self, (x, axis), (g_max, g_max_idx)): def grad(self, (x, axis), (g_max, g_max_idx)):
# @warning: This only works if axis is 0, else the max is # @warning: This only works if axis is 0, else the max is
# broadcasted wrong in the call to eq. # broadcasted wrong in the call to eq.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论