提交 3a7f2141 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Argmax: Fix axis=None

上级 c03cf9a6
...@@ -538,10 +538,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -538,10 +538,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
@register_funcify_default_op_cache_key(Argmax) @register_funcify_default_op_cache_key(Argmax)
def numba_funcify_Argmax(op, node, **kwargs): def numba_funcify_Argmax(op, node, **kwargs):
axis = op.axis axis = op.axis
x_at = node.inputs[0] x_pt = node.inputs[0]
x_dtype = x_at.type.numpy_dtype x_ndim = x_pt.ndim
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
x_ndim = x_at.ndim
if x_ndim == 0: if x_ndim == 0:
...@@ -550,7 +548,10 @@ def numba_funcify_Argmax(op, node, **kwargs): ...@@ -550,7 +548,10 @@ def numba_funcify_Argmax(op, node, **kwargs):
return np.array(0, dtype="int64") return np.array(0, dtype="int64")
else: else:
axes = tuple(int(ax) for ax in axis) if axis is None:
axes = tuple(range(x_ndim))
else:
axes = tuple(int(ax) for ax in axis)
# NumPy does not support multiple axes for argmax; this is a # NumPy does not support multiple axes for argmax; this is a
# work-around # work-around
...@@ -584,7 +585,8 @@ def numba_funcify_Argmax(op, node, **kwargs): ...@@ -584,7 +585,8 @@ def numba_funcify_Argmax(op, node, **kwargs):
return max_idx_res return max_idx_res
return argmax cache_version = 1
return argmax, cache_version
@register_funcify_default_op_cache_key(Dot) @register_funcify_default_op_cache_key(Dot)
......
...@@ -539,6 +539,11 @@ def test_Max(x, axes, exc): ...@@ -539,6 +539,11 @@ def test_Max(x, axes, exc):
[0, 1], [0, 1],
None, None,
), ),
(
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
None,
None,
),
], ],
) )
def test_Argmax(x, axes, exc): def test_Argmax(x, axes, exc):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论