提交 ec574156 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba UnravelIndex: Handle arbitrary indices ndim and F-order

上级 41868191
......@@ -261,41 +261,60 @@ def numba_funcify_Unique(op, node, **kwargs):
@register_funcify_and_cache_key(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
order = op.order
if order != "C":
raise NotImplementedError(
"Numba does not support the `order` argument in `numpy.unravel_index`"
)
out_ndim = node.outputs[0].type.ndim
if len(node.outputs) == 1:
@numba_basic.numba_njit(inline="always")
def maybe_expand_dim(arr):
return arr
else:
if out_ndim == 0:
# Creating a tuple of 0d arrays in numba is basically impossible without codegen, so just go to obj_mode
return generate_fallback_impl(op, node=node), None
@numba_basic.numba_njit(inline="always")
def maybe_expand_dim(arr):
return np.expand_dims(arr, 1)
c_order = op.order == "C"
inp_ndim = node.inputs[0].type.ndim
transpose_axes = (inp_ndim, *range(inp_ndim))
@numba_basic.numba_njit
def unravelindex(arr, shape):
def unravelindex(indices, shape):
a = np.ones(len(shape), dtype=np.int64)
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]
if c_order:
# C-Order: Reverse shape (ignore dim0), cumulative product, then reverse back
# Strides: [dim1*dim2, dim2, 1]
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]
else:
# F-Order: Standard shape, cumulative product
# Strides: [1, dim0, dim0*dim1]
a[1:] = shape[:-1]
a = np.cumprod(a)
# Broadcast with a and shape on the last axis
unraveled_coords = (indices[..., None] // a) % shape
# PyTensor actually returns a `tuple` of these values, instead of an
# `ndarray`; however, this `ndarray` result should be able to be
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return ((maybe_expand_dim(arr) // a) % shape).T
# Then transpose it to the front
# Numba doesn't have moveaxis (why would it), so we use transpose
# res = np.moveaxis(res, -1, 0)
unraveled_coords = unraveled_coords.transpose(transpose_axes)
# This should be a tuple, but the array can be unpacked
# into multiple variables with the same effect by the outer function
# (special case for single entry is handled with an outer function below)
return unraveled_coords
cache_version = 1
cache_key = sha256(
str((type(op), op.order, len(node.outputs))).encode()
str((type(op), op.order, len(node.outputs), cache_version)).encode()
).hexdigest()
return unravelindex, cache_key
if len(node.outputs) == 1:
@numba_basic.numba_njit
def unravel_index_single_item(arr, shape):
# Unpack single entry
(res,) = unravelindex(arr, shape)
return res
return unravel_index_single_item, cache_key
else:
return unravelindex, cache_key
@register_funcify_default_op_cache_key(SearchsortedOp)
......
......@@ -1304,13 +1304,11 @@ class UnravelIndex(Op):
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")
out_type = indices.type.clone(dtype="int64")
return Apply(
self,
[indices, dims],
[
TensorType(dtype="int64", shape=(None,) * indices.type.ndim)()
for i in range(ptb.get_vector_length(dims))
],
[out_type() for _i in range(ptb.get_vector_length(dims))],
)
def infer_shape(self, fgraph, node, input_shapes):
......
import contextlib
from contextlib import nullcontext
import numpy as np
import pytest
......@@ -295,37 +296,48 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
@pytest.mark.parametrize(
"arr, shape, order, exc",
"arr, shape, requires_obj_mode",
[
(
(pt.lscalar(), np.array(9, dtype="int64")),
pt.as_tensor([2, 3, 4]),
True,
),
(
(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
pt.as_tensor([2, 3, 4]),
"C",
None,
False,
),
(
(pt.lvector(), np.array([1, 0], dtype="int64")),
pt.as_tensor([2]),
"C",
None,
False,
),
(
(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
(pt.lmatrix(), np.array([[9, 15, 1], [1, 9, 15]], dtype="int64")),
pt.as_tensor([2, 3, 4]),
"F",
NotImplementedError,
False,
),
],
)
def test_UnravelIndex(arr, shape, order, exc):
def test_UnravelIndex(arr, shape, requires_obj_mode):
arr, test_arr = arr
g = extra_ops.UnravelIndex(order)(arr, shape)
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
g_c = extra_ops.UnravelIndex("C")(arr, shape)
g_f = extra_ops.UnravelIndex("F")(arr, shape)
if shape.type.shape == (1,):
outputs = [g_c, g_f]
else:
outputs = [*g_c, *g_f]
cm = (
pytest.warns(UserWarning, match="object mode")
if requires_obj_mode
else nullcontext()
)
with cm:
compare_numba_and_py(
[arr],
g,
outputs,
[test_arr],
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论