提交 0de0fa9b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow broadcasting in specialized numba dispatch of AdvancedIncSubtensor

上级 52bbf59d
......@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(idx.type, TensorType)
]
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
# Special implementation for consecutive integer vector indices
if (
not basic_idxs
......@@ -151,17 +142,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
)
# Must be consecutive
and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
or not broadcasted_to(
y.type.broadcastable,
(
x.type.broadcastable[: adv_idxs[0]["axis"]]
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
),
)
)
):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
......@@ -191,14 +171,24 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
return numba_funcify_default_subtensor(op, node, **kwargs)
def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
def numba_funcify_multiple_integer_vector_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if isinstance(op, AdvancedSubtensor):
y, idxs = None, node.inputs[1:]
idxs = node.inputs[1:]
else:
y, *idxs = node.inputs[1:]
idxs = node.inputs[2:]
first_axis = next(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
......@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
)
except StopIteration:
after_last_axis = len(idxs)
last_axis = after_last_axis - 1
vector_indices = idxs[first_axis:after_last_axis]
assert all(v.type.broadcastable == (False,) for v in vector_indices)
if isinstance(op, AdvancedSubtensor):
......@@ -231,9 +225,20 @@ def numba_funcify_multiple_integer_vector_indexing(
return advanced_subtensor_multiple_vector
elif op.set_instead_of_inc:
else:
inplace = op.inplace
# Check if y must be broadcasted
# Includes the last integer vector index,
x, y = node.inputs[:2]
indexed_bcast_dims = (
*x.type.broadcastable[:first_axis],
*x.type.broadcastable[last_axis:],
)
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
if op.set_instead_of_inc:
@numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
......@@ -244,6 +249,9 @@ def numba_funcify_multiple_integer_vector_indexing(
else:
out = x.copy()
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
......@@ -252,7 +260,6 @@ def numba_funcify_multiple_integer_vector_indexing(
return advanced_set_subtensor_multiple_vector
else:
inplace = op.inplace
@numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
......@@ -264,6 +271,9 @@ def numba_funcify_multiple_integer_vector_indexing(
else:
out = x.copy()
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
......
......@@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices):
np.array(-99), # Broadcasted value
([1, 2], [2, 3]), # 2 vector indices
False,
True,
True,
False,
False,
),
(
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论