提交 a7827536 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba RavelMultiIndex: Handle arbitrary indices ndim and F-order

上级 ec574156
...@@ -135,65 +135,49 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs): ...@@ -135,65 +135,49 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
def numba_funcify_RavelMultiIndex(op, node, **kwargs): def numba_funcify_RavelMultiIndex(op, node, **kwargs):
mode = op.mode mode = op.mode
order = op.order order = op.order
vec_indices = node.inputs[0].type.ndim > 0
if order != "C": @numba_basic.numba_njit
raise NotImplementedError( def ravelmultiindex(*inp):
"Numba does not implement `order` in `numpy.ravel_multi_index`" shape = inp[-1]
) # Concatenate indices along last axis
stacked_indices = np.stack(inp[:-1], axis=-1)
if mode == "raise":
# Manage invalid indices
@numba_basic.numba_njit for i, dim_limit in enumerate(shape):
def mode_fn(*args): if mode == "wrap":
raise ValueError("invalid entry in coordinates array") stacked_indices[..., i] %= dim_limit
elif mode == "clip":
elif mode == "wrap": dim_indices = stacked_indices[..., i]
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
@numba_basic.numba_njit(inline="always") else: # raise
def mode_fn(new_arr, i, j, v, d): dim_indices = stacked_indices[..., i]
new_arr[i, j] = v % d invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
# Cannot call np.any on a boolean
elif mode == "clip": if vec_indices:
invalid_indices = invalid_indices.any()
@numba_basic.numba_njit(inline="always") if invalid_indices:
def mode_fn(new_arr, i, j, v, d): raise ValueError("invalid entry in coordinates array")
new_arr[i, j] = min(max(v, 0), d - 1)
# Calculate Strides based on Order
if node.inputs[0].ndim == 0: a = np.ones(len(shape), dtype=np.int64)
if order == "C":
@numba_basic.numba_njit # C-Order: Last dimension moves fastest (Strides: large -> small -> 1)
def ravelmultiindex(*inp): # For shape (3, 4, 5): Multipliers are (20, 5, 1)
shape = inp[-1] if len(shape) > 1:
arr = np.stack(inp[:-1]) a[:-1] = np.cumprod(shape[:0:-1])[::-1]
else: # order == "F"
new_arr = arr.T.astype(np.float64).copy() # F-Order: First dimension moves fastest (Strides: 1 -> small -> large)
for i, b in enumerate(new_arr): # For shape (3, 4, 5): Multipliers are (1, 3, 12)
if b < 0 or b >= shape[i]: if len(shape) > 1:
mode_fn(new_arr, i, 0, b, shape[i]) a[1:] = np.cumprod(shape[:-1])
a = np.ones(len(shape), dtype=np.float64) # Dot product indices with strides
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1] # (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support)
return np.array(a.dot(new_arr.T), dtype=np.int64) return np.asarray((stacked_indices * a).sum(-1))
else:
@numba_basic.numba_njit cache_version = 1
def ravelmultiindex(*inp): return ravelmultiindex, cache_version
shape = inp[-1]
arr = np.stack(inp[:-1])
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
# no strict argument to this zip because numba doesn't support it
for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d)
a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return a.dot(new_arr.T).astype(np.int64)
return ravelmultiindex
@register_funcify_default_op_cache_key(Repeat) @register_funcify_default_op_cache_key(Repeat)
......
...@@ -1371,8 +1371,7 @@ class RavelMultiIndex(Op): ...@@ -1371,8 +1371,7 @@ class RavelMultiIndex(Op):
self.order = order self.order = order
def make_node(self, *inp): def make_node(self, *inp):
multi_index = [ptb.as_tensor_variable(i) for i in inp[:-1]] *multi_index, dims = map(ptb.as_tensor_variable, inp)
dims = ptb.as_tensor_variable(inp[-1])
for i in multi_index: for i in multi_index:
if i.dtype not in int_dtypes: if i.dtype not in int_dtypes:
...@@ -1382,19 +1381,20 @@ class RavelMultiIndex(Op): ...@@ -1382,19 +1381,20 @@ class RavelMultiIndex(Op):
if dims.ndim != 1: if dims.ndim != 1:
raise TypeError("dims must be a 1D array") raise TypeError("dims must be a 1D array")
out_type = multi_index[0].type.clone(dtype="int64")
return Apply( return Apply(
self, self,
[*multi_index, dims], [*multi_index, dims],
[TensorType(dtype="int64", shape=(None,) * multi_index[0].type.ndim)()], [out_type()],
) )
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]] return [input_shapes[0]]
def perform(self, node, inp, out): def perform(self, node, inp, out):
multi_index, dims = inp[:-1], inp[-1] *multi_index, dims = inp
res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order) res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
out[0][0] = np.asarray(res, node.outputs[0].dtype) out[0][0] = np.asarray(res, "int64")
def ravel_multi_index(multi_index, dims, mode="raise", order="C"): def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
......
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.tensor import extra_ops from pytensor.tensor import extra_ops
from pytensor.tensor.extra_ops import RavelMultiIndex
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -133,35 +134,34 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -133,35 +134,34 @@ def test_FillDiagonalOffset(a, val, offset):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"arr, shape, mode, order, exc", "arr, shape, mode, exc",
[ [
( (
tuple((pt.lscalar(), v) for v in np.array([0])), tuple((pt.lscalar(), v) for v in np.array([0])),
(pt.lvector(), np.array([2])), (pt.lvector(), np.array([2])),
"raise", "raise",
"C",
None, None,
), ),
( (
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])), tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"C",
None, None,
), ),
( (
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])), tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"C",
None, None,
), ),
( (
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])), tuple(
(pt.lmatrix(), np.broadcast_to(v, (3, 2)).copy())
for v in np.array([[0, 1], [2, 0], [1, 3]])
),
(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"F", None,
NotImplementedError,
), ),
( (
tuple( tuple(
...@@ -169,7 +169,6 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -169,7 +169,6 @@ def test_FillDiagonalOffset(a, val, offset):
), ),
(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"C",
ValueError, ValueError,
), ),
( (
...@@ -178,7 +177,15 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -178,7 +177,15 @@ def test_FillDiagonalOffset(a, val, offset):
), ),
(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"wrap", "wrap",
"C", None,
),
(
tuple(
(pt.ltensor3(), np.broadcast_to(v, (2, 2, 3)).copy())
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
(pt.lvector(), np.array([2, 3, 4])),
"wrap",
None, None,
), ),
( (
...@@ -187,21 +194,30 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -187,21 +194,30 @@ def test_FillDiagonalOffset(a, val, offset):
), ),
(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"clip", "clip",
"C", None,
),
(
tuple(
(pt.lmatrix(), np.broadcast_to(v, (2, 3)).copy())
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
(pt.lvector(), np.array([2, 3, 4])),
"clip",
None, None,
), ),
], ],
) )
def test_RavelMultiIndex(arr, shape, mode, order, exc): def test_RavelMultiIndex(arr, shape, mode, exc):
arr, test_arr = zip(*arr, strict=True) arr, test_arr = zip(*arr, strict=True)
shape, test_shape = shape shape, test_shape = shape
g = extra_ops.RavelMultiIndex(mode, order)(*arr, shape) g_c = RavelMultiIndex(mode, order="C")(*arr, shape)
g_f = RavelMultiIndex(mode, order="F")(*arr, shape)
cm = contextlib.suppress() if exc is None else pytest.raises(exc) cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
[*arr, shape], [*arr, shape],
g, [g_c, g_f],
[*test_arr, test_shape], [*test_arr, test_shape],
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论