Commit ec574156, authored by Ricardo Vieira, committed by Ricardo Vieira

Numba UnravelIndex: Handle arbitrary indices ndim and F-order

Parent commit: 41868191
...@@ -261,41 +261,60 @@ def numba_funcify_Unique(op, node, **kwargs): ...@@ -261,41 +261,60 @@ def numba_funcify_Unique(op, node, **kwargs):
@register_funcify_and_cache_key(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
    """Generate a Numba implementation of ``UnravelIndex``.

    Handles arbitrary ``indices`` ndim and both "C" and "F" order.
    Returns a ``(jitted_function, cache_key)`` pair; the cache key is
    ``None`` when falling back to object mode.
    """
    out_ndim = node.outputs[0].type.ndim

    if out_ndim == 0:
        # Creating a tuple of 0d arrays in numba is basically impossible
        # without codegen, so just go to obj_mode.
        return generate_fallback_impl(op, node=node), None

    c_order = op.order == "C"
    inp_ndim = node.inputs[0].type.ndim
    # Axes permutation that moves the trailing coordinate axis to the front,
    # equivalent to np.moveaxis(res, -1, 0).
    transpose_axes = (inp_ndim, *range(inp_ndim))

    @numba_basic.numba_njit
    def unravelindex(indices, shape):
        a = np.ones(len(shape), dtype=np.int64)
        if c_order:
            # C-Order: Reverse shape (ignore dim0), cumulative product,
            # then reverse back.
            # Strides: [dim1*dim2, dim2, 1]
            a[1:] = shape[:0:-1]
            a = np.cumprod(a)[::-1]
        else:
            # F-Order: Standard shape, cumulative product.
            # Strides: [1, dim0, dim0*dim1]
            a[1:] = shape[:-1]
            a = np.cumprod(a)

        # Broadcast with a and shape on the last axis
        unraveled_coords = (indices[..., None] // a) % shape
        # Then transpose it to the front.
        # Numba doesn't have moveaxis (why would it), so we use transpose:
        # res = np.moveaxis(res, -1, 0)
        unraveled_coords = unraveled_coords.transpose(transpose_axes)
        # This should be a tuple, but the array can be unpacked
        # into multiple variables with the same effect by the outer function
        # (special case for single entry is handled with an outer function below)
        return unraveled_coords

    # Bump when the implementation changes so stale compiled caches are invalidated.
    cache_version = 1
    cache_key = sha256(
        str((type(op), op.order, len(node.outputs), cache_version)).encode()
    ).hexdigest()

    if len(node.outputs) == 1:

        @numba_basic.numba_njit
        def unravel_index_single_item(arr, shape):
            # Unpack single entry
            (res,) = unravelindex(arr, shape)
            return res

        return unravel_index_single_item, cache_key
    else:
        return unravelindex, cache_key
@register_funcify_default_op_cache_key(SearchsortedOp) @register_funcify_default_op_cache_key(SearchsortedOp)
......
...@@ -1304,13 +1304,11 @@ class UnravelIndex(Op): ...@@ -1304,13 +1304,11 @@ class UnravelIndex(Op):
if dims.ndim != 1: if dims.ndim != 1:
raise TypeError("dims must be a 1D array") raise TypeError("dims must be a 1D array")
out_type = indices.type.clone(dtype="int64")
return Apply( return Apply(
self, self,
[indices, dims], [indices, dims],
[ [out_type() for _i in range(ptb.get_vector_length(dims))],
TensorType(dtype="int64", shape=(None,) * indices.type.ndim)()
for i in range(ptb.get_vector_length(dims))
],
) )
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
......
import contextlib import contextlib
from contextlib import nullcontext
import numpy as np import numpy as np
import pytest import pytest
...@@ -295,37 +296,48 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): ...@@ -295,37 +296,48 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
@pytest.mark.parametrize(
    "arr, shape, requires_obj_mode",
    [
        (
            # Scalar indices: 0d outputs force the object-mode fallback.
            (pt.lscalar(), np.array(9, dtype="int64")),
            pt.as_tensor([2, 3, 4]),
            True,
        ),
        (
            (pt.lvector(), np.array([9, 15, 1], dtype="int64")),
            pt.as_tensor([2, 3, 4]),
            False,
        ),
        (
            # Single-dim shape: the op returns one output instead of a tuple.
            (pt.lvector(), np.array([1, 0], dtype="int64")),
            pt.as_tensor([2]),
            False,
        ),
        (
            # Matrix indices: exercises ndim > 1 inputs.
            (pt.lmatrix(), np.array([[9, 15, 1], [1, 9, 15]], dtype="int64")),
            pt.as_tensor([2, 3, 4]),
            False,
        ),
    ],
)
def test_UnravelIndex(arr, shape, requires_obj_mode):
    """Compare the Numba UnravelIndex impl to the Python one for C and F order."""
    arr, test_arr = arr
    g_c = extra_ops.UnravelIndex("C")(arr, shape)
    g_f = extra_ops.UnravelIndex("F")(arr, shape)

    if shape.type.shape == (1,):
        # With a length-1 shape each op returns a single variable, not a tuple.
        outputs = [g_c, g_f]
    else:
        outputs = [*g_c, *g_f]

    # The 0d case falls back to object mode, which emits a UserWarning.
    cm = (
        pytest.warns(UserWarning, match="object mode")
        if requires_obj_mode
        else nullcontext()
    )
    with cm:
        compare_numba_and_py(
            [arr],
            outputs,
            [test_arr],
        )
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment.