提交 fe10f960 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix AdvancedSubtensor static shape with newaxis

上级 ae729683
...@@ -2629,9 +2629,13 @@ class AdvancedSubtensor(Op): ...@@ -2629,9 +2629,13 @@ class AdvancedSubtensor(Op):
advanced_indices = [] advanced_indices = []
adv_group_axis = None adv_group_axis = None
last_adv_group_axis = None last_adv_group_axis = None
expanded_x_shape = tuple( if new_axes:
np.insert(np.array(x.type.shape, dtype=object), 1, new_axes) expanded_x_shape_list = list(x.type.shape)
) for new_axis in new_axes:
expanded_x_shape_list.insert(new_axis, 1)
expanded_x_shape = tuple(expanded_x_shape_list)
else:
expanded_x_shape = x.type.shape
for i, (idx, dim_length) in enumerate( for i, (idx, dim_length) in enumerate(
zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst) zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
): ):
......
...@@ -1856,6 +1856,7 @@ class TestAdvancedSubtensor: ...@@ -1856,6 +1856,7 @@ class TestAdvancedSubtensor:
assert x[idx1].type.shape == (10, None) assert x[idx1].type.shape == (10, None)
assert x[:, idx1].type.shape == (None, 10) assert x[:, idx1].type.shape == (None, 10)
assert x[None, :, idx1].type.shape == (1, None, 10)
assert x[idx2, :5].type.shape == (3, None, None) assert x[idx2, :5].type.shape == (3, None, None)
assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5) assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3) assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论