提交 a7827536 作者: Ricardo Vieira 提交者: Ricardo Vieira

Numba RavelMultiIndex: Handle arbitrary indices ndim and F-order

上级 ec574156
......@@ -135,65 +135,49 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
    """Build the Numba implementation of ``RavelMultiIndex``.

    The compiled function receives the index arrays followed by the ``dims``
    vector as its last argument and returns the flattened indices.
    Out-of-bounds entries are handled according to ``op.mode`` ("wrap",
    "clip", otherwise an error is raised) and the multipliers follow
    ``op.order`` ("C", otherwise treated as "F").

    Returns
    -------
    tuple
        ``(ravelmultiindex, cache_version)``.

    Raises
    ------
    ValueError
        Raised by the compiled function, when the mode is neither "wrap" nor
        "clip", if an index is negative or not smaller than the matching
        entry of ``dims``.
    """
    mode = op.mode
    order = op.order
    # With 0-d indices the bounds check below yields a plain boolean, which
    # cannot be reduced with `.any()`; remember which case we are in.
    vec_indices = node.inputs[0].type.ndim > 0

    @numba_basic.numba_njit
    def ravelmultiindex(*inp):
        shape = inp[-1]
        # Concatenate indices along last axis
        stacked_indices = np.stack(inp[:-1], axis=-1)

        # Manage invalid indices, one dimension at a time
        for i, dim_limit in enumerate(shape):
            if mode == "wrap":
                stacked_indices[..., i] %= dim_limit
            elif mode == "clip":
                dim_indices = stacked_indices[..., i]
                stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
            else:  # raise
                dim_indices = stacked_indices[..., i]
                invalid_indices = (dim_indices < 0) | (dim_indices >= dim_limit)
                # Cannot call np.any on a boolean
                if vec_indices:
                    invalid_indices = invalid_indices.any()
                if invalid_indices:
                    raise ValueError("invalid entry in coordinates array")

        # Calculate strides (in elements) based on order
        a = np.ones(len(shape), dtype=np.int64)
        if order == "C":
            # C-Order: Last dimension moves fastest (Strides: large -> small -> 1)
            # For shape (3, 4, 5): Multipliers are (20, 5, 1)
            if len(shape) > 1:
                a[:-1] = np.cumprod(shape[:0:-1])[::-1]
        else:  # order == "F"
            # F-Order: First dimension moves fastest (Strides: 1 -> small -> large)
            # For shape (3, 4, 5): Multipliers are (1, 3, 12)
            if len(shape) > 1:
                a[1:] = np.cumprod(shape[:-1])

        # Dot product indices with strides
        # (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support)
        return np.asarray((stacked_indices * a).sum(-1))

    cache_version = 1
    return ravelmultiindex, cache_version
@register_funcify_default_op_cache_key(Repeat)
......
......@@ -1371,8 +1371,7 @@ class RavelMultiIndex(Op):
self.order = order
def make_node(self, *inp):
multi_index = [ptb.as_tensor_variable(i) for i in inp[:-1]]
dims = ptb.as_tensor_variable(inp[-1])
*multi_index, dims = map(ptb.as_tensor_variable, inp)
for i in multi_index:
if i.dtype not in int_dtypes:
......@@ -1382,19 +1381,20 @@ class RavelMultiIndex(Op):
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")
out_type = multi_index[0].type.clone(dtype="int64")
return Apply(
self,
[*multi_index, dims],
[TensorType(dtype="int64", shape=(None,) * multi_index[0].type.ndim)()],
[out_type()],
)
def infer_shape(self, fgraph, node, input_shapes):
    """Return the output shape, which is the shape of the first input.

    ``fgraph`` and ``node`` are accepted but not used here.
    """
    return [input_shapes[0]]
def perform(self, node, inp, out):
    """Compute the flat indices with NumPy and store them in ``out[0][0]``.

    ``inp`` holds the index arrays followed by ``dims`` as its last entry.
    ``self.mode`` and ``self.order`` are forwarded to
    ``np.ravel_multi_index``; the result is always stored as ``int64``.
    ``node`` is not used.
    """
    *multi_index, dims = inp
    # np.ravel_multi_index documents its first argument as a tuple of arrays.
    res = np.ravel_multi_index(
        tuple(multi_index), dims, mode=self.mode, order=self.order
    )
    out[0][0] = np.asarray(res, "int64")
def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
......
......@@ -7,6 +7,7 @@ import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor import extra_ops
from pytensor.tensor.extra_ops import RavelMultiIndex
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -133,35 +134,34 @@ def test_FillDiagonalOffset(a, val, offset):
@pytest.mark.parametrize(
"arr, shape, mode, order, exc",
"arr, shape, mode, exc",
[
(
tuple((pt.lscalar(), v) for v in np.array([0])),
(pt.lvector(), np.array([2])),
"raise",
"C",
None,
),
(
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"C",
None,
),
(
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"C",
None,
),
(
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
tuple(
(pt.lmatrix(), np.broadcast_to(v, (3, 2)).copy())
for v in np.array([[0, 1], [2, 0], [1, 3]])
),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"F",
NotImplementedError,
None,
),
(
tuple(
......@@ -169,7 +169,6 @@ def test_FillDiagonalOffset(a, val, offset):
),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"C",
ValueError,
),
(
......@@ -178,7 +177,15 @@ def test_FillDiagonalOffset(a, val, offset):
),
(pt.lvector(), np.array([2, 3, 4])),
"wrap",
"C",
None,
),
(
tuple(
(pt.ltensor3(), np.broadcast_to(v, (2, 2, 3)).copy())
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
(pt.lvector(), np.array([2, 3, 4])),
"wrap",
None,
),
(
......@@ -187,21 +194,30 @@ def test_FillDiagonalOffset(a, val, offset):
),
(pt.lvector(), np.array([2, 3, 4])),
"clip",
"C",
None,
),
(
tuple(
(pt.lmatrix(), np.broadcast_to(v, (2, 3)).copy())
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
(pt.lvector(), np.array([2, 3, 4])),
"clip",
None,
),
],
)
def test_RavelMultiIndex(arr, shape, mode, exc):
    """Compare the Numba and Python results of ``RavelMultiIndex`` for both orders.

    ``arr`` is an iterable of ``(variable, test_value)`` pairs for the index
    inputs and ``shape`` is a single ``(variable, test_value)`` pair for the
    dims input. ``exc`` is the exception type expected from the comparison,
    or ``None`` when it should succeed.
    """
    arr, test_arr = zip(*arr, strict=True)
    shape, test_shape = shape
    # Build one graph per memory order so both are checked in a single call.
    g_c = RavelMultiIndex(mode, order="C")(*arr, shape)
    g_f = RavelMultiIndex(mode, order="F")(*arr, shape)

    cm = contextlib.suppress() if exc is None else pytest.raises(exc)
    with cm:
        compare_numba_and_py(
            [*arr, shape],
            [g_c, g_f],
            [*test_arr, test_shape],
        )
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论