提交 2fbbe138 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix the logic for where to put axes back

上级 ebaef5af
...@@ -511,40 +511,48 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor): ...@@ -511,40 +511,48 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
x = x.reshape(nshp) x = x.reshape(nshp)
narrays = 0
transp = list(range(x.ndim)) transp = list(range(x.ndim))
# number of array-indexed dimensions
p = 0 p = 0
# ap gives the position of the array in case there is only one. # ap represents the axis in the resulting array where the
# if there are more than one (narray > 1) it should be ignored. # dimensions indexed by arrays and ints will be inserted.
ap = 0 # For instance, if all such dimensions are grouped together,
# it corresponds to the index of the first such dimension in the
# inital array. If these dimensions are split (with slices
# inbetween), then the resulting dimensions will be moved to the
# beginning, and ap will be 0.
# If no such dimension has been encountered, ap is None.
ap = None
# Indicates whether we have already encountered an index (array
# or number), and then a slice.
slice_after_idx = False
for k, i in enumerate(list(nidx)): for k, i in enumerate(list(nidx)):
if (isinstance(i, np.ndarray) and if (isinstance(i, np.ndarray) and i.ndim != 0):
i.ndim != 0):
transp.remove(k) transp.remove(k)
transp.insert(p, k) transp.insert(p, k)
ap += k
i = nidx.pop(k) i = nidx.pop(k)
nidx.insert(p, i) nidx.insert(p, i)
p += 1 p += 1
narrays += 1 if ap is None:
# first non-slice index
ap = k
elif slice_after_idx:
# We already encountered at least an array or int, and then
# a slice. Array-indexed axes are not grouped,
# moving to the beginning
ap = 0
else: else:
if narrays == 0:
try: try:
i.__index__() i.__index__()
# We shift back the position of the array by the if ap is None:
# number of dimensions that are removed by ap = k
# indexing. If ap is bigger than 0 it means we # indices do not break the contiguity of
# have encountered at least one array. # array-indexed axes
if ap >= 0:
ap -= 1
# If this index is before the first array then
# we will not move the array back to its
# position. Mark this by faking that there
# are more than two arrays. This is crazy
# numpy behaviour so blame them.
narrays = 2
except Exception: except Exception:
pass # If we already encountered an array/int index, it
# means future ones will not be grouped.
if ap is not None:
slice_after_idx = True
x = x.transpose(*transp) x = x.transpose(*transp)
...@@ -552,12 +560,16 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor): ...@@ -552,12 +560,16 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
x = x.__getitem__(idx_) x = x.__getitem__(idx_)
if p == 0: if p == 0:
assert ap is None
# The only indexing was through slices and indices. # The only indexing was through slices and indices.
# This can happen with symbolic slices for instance. # This can happen with symbolic slices for instance.
# Since no view_map is set, we need to copy the returned value # Since no view_map is set, we need to copy the returned value
out[0] = x.copy() out[0] = x.copy()
return return
# At this point, we should have encountered at least one array
assert ap is not None
# flatten the array-indexed dimensions # flatten the array-indexed dimensions
shape = ((np.prod(x.shape[0: p]),) + shape = ((np.prod(x.shape[0: p]),) +
x.shape[p:]) x.shape[p:])
...@@ -578,10 +590,9 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor): ...@@ -578,10 +590,9 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
out_flat_shp = take_idx.shape + x.shape[p:] out_flat_shp = take_idx.shape + x.shape[p:]
o = out_flat.reshape(out_flat_shp) o = out_flat.reshape(out_flat_shp)
# If there was only one array we need to move the indexed if ap != 0:
# dimension(s) back to the position of the array, which is # Put the resulting indexing at the place that NumPy
# stored in ap. Note that ap is invalid is narrays != 1. # decided was the right one.
if narrays == 1:
ntransp = list(range(take_idx.ndim, o.ndim)) ntransp = list(range(take_idx.ndim, o.ndim))
ntransp[ap:ap] = list(range(take_idx.ndim)) ntransp[ap:ap] = list(range(take_idx.ndim))
o = o.transpose(*ntransp) o = o.transpose(*ntransp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论