提交 3a7f2141 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Argmax: Fix axis=None

上级 c03cf9a6
......@@ -538,10 +538,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
@register_funcify_default_op_cache_key(Argmax)
def numba_funcify_Argmax(op, node, **kwargs):
axis = op.axis
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
x_ndim = x_at.ndim
x_pt = node.inputs[0]
x_ndim = x_pt.ndim
if x_ndim == 0:
......@@ -550,7 +548,10 @@ def numba_funcify_Argmax(op, node, **kwargs):
return np.array(0, dtype="int64")
else:
axes = tuple(int(ax) for ax in axis)
if axis is None:
axes = tuple(range(x_ndim))
else:
axes = tuple(int(ax) for ax in axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
......@@ -584,7 +585,8 @@ def numba_funcify_Argmax(op, node, **kwargs):
return max_idx_res
return argmax
cache_version = 1
return argmax, cache_version
@register_funcify_default_op_cache_key(Dot)
......
......@@ -539,6 +539,11 @@ def test_Max(x, axes, exc):
[0, 1],
None,
),
(
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
None,
None,
),
],
)
def test_Argmax(x, axes, exc):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论