Unverified 提交 d159f06d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix bug in AdvancedSubtensor infer_shape (#101)

* Fix bug in AdvancedSubtensor infer_shape The underlying utility `indexed_result_shape` was off by 1 in terms of when do the advanced index operations have to be brought to the front of the array.
上级 befc177d
...@@ -489,8 +489,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -489,8 +489,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape)) remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices) idx_groups = group_indices(indices)
if len(idx_groups) > 2 or len(idx_groups) > 1 and not idx_groups[0][0]: if len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0]):
# Bring adv. index groups to the front and merge each group # This means that there are at least two groups of advanced indexing separated by basic indexing
# In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0]) idx_groups = sorted(idx_groups, key=lambda x: x[0])
idx_groups = groupby( idx_groups = groupby(
chain.from_iterable(d_idx for _, d_idx in idx_groups), chain.from_iterable(d_idx for _, d_idx in idx_groups),
......
...@@ -2517,6 +2517,10 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True ...@@ -2517,6 +2517,10 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
), ),
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))), (np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])), (np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
(
np.arange(np.prod((5, 6, 7))).reshape((5, 6, 7)),
(slice(None, None), [1, 2, 3], slice(None, None)),
),
], ],
) )
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论