提交 0de0fa9b 作者: Ricardo Vieira 提交者: Ricardo Vieira

Allow broadcasting in specialized numba dispatch of AdvancedIncSubtensor

上级 52bbf59d
...@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(idx.type, TensorType) if isinstance(idx.type, TensorType)
] ]
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
# Special implementation for consecutive integer vector indices # Special implementation for consecutive integer vector indices
if ( if (
not basic_idxs not basic_idxs
...@@ -151,17 +142,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -151,17 +142,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
) )
# Must be consecutive # Must be consecutive
and not op.non_consecutive_adv_indexing(node) and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
or not broadcasted_to(
y.type.broadcastable,
(
x.type.broadcastable[: adv_idxs[0]["axis"]]
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
),
)
)
): ):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs) return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
...@@ -191,14 +171,24 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -191,14 +171,24 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
return numba_funcify_default_subtensor(op, node, **kwargs) return numba_funcify_default_subtensor(op, node, **kwargs)
def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
# Check that x is not broadcasted to y based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True
return False
def numba_funcify_multiple_integer_vector_indexing( def numba_funcify_multiple_integer_vector_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
): ):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor) # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if isinstance(op, AdvancedSubtensor): if isinstance(op, AdvancedSubtensor):
y, idxs = None, node.inputs[1:] idxs = node.inputs[1:]
else: else:
y, *idxs = node.inputs[1:] idxs = node.inputs[2:]
first_axis = next( first_axis = next(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType) i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
...@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
) )
except StopIteration: except StopIteration:
after_last_axis = len(idxs) after_last_axis = len(idxs)
last_axis = after_last_axis - 1
vector_indices = idxs[first_axis:after_last_axis]
assert all(v.type.broadcastable == (False,) for v in vector_indices)
if isinstance(op, AdvancedSubtensor): if isinstance(op, AdvancedSubtensor):
...@@ -231,9 +225,20 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -231,9 +225,20 @@ def numba_funcify_multiple_integer_vector_indexing(
return advanced_subtensor_multiple_vector return advanced_subtensor_multiple_vector
elif op.set_instead_of_inc: else:
inplace = op.inplace inplace = op.inplace
# Check if y must be broadcasted
# Includes the last integer vector index,
x, y = node.inputs[:2]
indexed_bcast_dims = (
*x.type.broadcastable[:first_axis],
*x.type.broadcastable[last_axis:],
)
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
if op.set_instead_of_inc:
@numba_njit @numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs): def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis] vec_idxs = idxs[first_axis:after_last_axis]
...@@ -244,6 +249,9 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -244,6 +249,9 @@ def numba_funcify_multiple_integer_vector_indexing(
else: else:
out = x.copy() out = x.copy()
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]): for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] = y[(*outer, i)] out[(*outer, *scalar_idxs)] = y[(*outer, i)]
...@@ -252,7 +260,6 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -252,7 +260,6 @@ def numba_funcify_multiple_integer_vector_indexing(
return advanced_set_subtensor_multiple_vector return advanced_set_subtensor_multiple_vector
else: else:
inplace = op.inplace
@numba_njit @numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs): def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
...@@ -264,6 +271,9 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -264,6 +271,9 @@ def numba_funcify_multiple_integer_vector_indexing(
else: else:
out = x.copy() out = x.copy()
if y_is_broadcasted:
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]): for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] += y[(*outer, i)] out[(*outer, *scalar_idxs)] += y[(*outer, i)]
......
...@@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices):
np.array(-99), # Broadcasted value np.array(-99), # Broadcasted value
([1, 2], [2, 3]), # 2 vector indices ([1, 2], [2, 3]), # 2 vector indices
False, False,
True, False,
True, False,
), ),
( (
np.arange(3 * 4 * 5).reshape((3, 4, 5)), np.arange(3 * 4 * 5).reshape((3, 4, 5)),
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论