提交 7a45b553 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added a specialized function to deal only with merging two slices.

上级 98eb9b65
...@@ -1153,6 +1153,116 @@ def local_subtensor_lift(node): ...@@ -1153,6 +1153,116 @@ def local_subtensor_lift(node):
new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim)) new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim))
return [u.owner.op(*new_inputs)] return [u.owner.op(*new_inputs)]
def merge_two_slices(slice1, len1, slice2, len2):
'''
This function merges two slices into a single slice. The code works on
the assumption that:
a) slice1 is actually a slice and not an index, while slice2
can be just an index.
b) the two slices **have been applied consecutively** on the same
tensor
The output slice is **not** in canonical form, but actually just a slice
that can be applied to a tensor to produce the same output as applying
the two consecutive slices.
``len1`` is the length of the tensor **before** applying the first slice,
while ``len2`` is the length **after** applying the first slice.
'''
if type(slice1) is not slice:
raise ValueError( ('First provided slice should actually be of type'
'slice and not an index !'),slice1)
sl1, reverse1 = T.get_canonical_form_slice(slice1, len1)
sl2, reverse2 = T.get_canonical_form_slice(slice2, len2)
if type(sl2) is not slice:
if reverse1 is None:
# The first slice is not in reverse, which makes things a lot
# more clear.
# In this case we need to take care only of the special cases:
# len2 <=0 -> throw index error regardless of sl2
# sl2 > len2 -> throw index error
# sl2 < -len2 -> throw index error
# To get a index error we simply use len1+1 to indicate we are
# out of bounds, because passing this index through the formula
# of getting the mixed slice is not guaranteed to result in an
# index error. The **issue though** if that the error will
# complain about accessing element len1+1 which is probably not
# too intuitive for the user
val = sl1.start + sl2*sl1.step
val = T.switch(T.le(len2,0 ), len1+1, val)
val = T.switch(T.ge(sl2 ,len2), len1+1, val)
val = T.switch(T.lt(sl2, 0 ), -len1-1, val)
if sl1.step:
val = T.switch(T.eq(sl1.step,0), len1+1, val)
return val
else:
# We are in the more complex case when we do not actually know
# if the first slice was in reverse or not.
# in case it was not in reverse:
p_val = sl1.start + sl2*sl1.step
# case it was in reverse we need to realize that we do not want
# the k-th element from sl.start but the k-th element from
# sl.stop backwards
n_val = sl1.stop - sl1.start - 1 - sl2*sl1.step
# we need to pick either n_val or p_val and then follow same
# steps as above for covering the index error cases
val = T.switch(T.lt(reverse1,0), n_val, p_val)
val = T.switch(T.le(len2,0 ), len1+1, val)
val = T.switch(T.ge(sl2 ,len2), len1+1, val)
val = T.switch(T.lt(sl2, 0 ), -len1-1, val)
if sl1.step:
val = T.switch(T.eq(sl1.step,0), len1+1, val)
return val
else:
# We are deleaing with two slices that need to be put together
# according to the two steps we have 4 different combinations of
# positive/negative. I will denote the case I'm looking at by
# suffixes to the variables (nn,np,pn,pp):
pp_start = sl1.start + sl2.start
pp_stop = sl1.start + sl2.stop
pp_step = sl1.step * sl2.step
pn_start = sl1.start + sl2.start
pn_stop = sl1.start + sl2.stop
pn_step = sl1.step * sl2.step * -1
np_start = sl1.stop - sl2.stop
np_stop = sl1.stop - sl2.start
np_step = sl1.step * sl2.step * -1
nn_start = sl1.stop - sl2.start
nn_stop = sl1.stop - sl2.stop
nn_step = sl1.step * sl2.step
if reverse1 is None and reverse2 is None:
start = pp_start
stop = pp_stop
step = pp_step
elif reverse1 is not None and reverse2 is None:
start = T.switch(lt(reverse1,0), np_start, pp_start)
stop = T.switch(lt(reverse1,0), np_stop , pp_stop )
step = T.switch(lt(reverse1,0), np_step , pp_step )
elif reverse1 is None and reverse2 is not None:
start = T.switch(lt(reverse2,0), pn_start, pp_start)
stop = T.switch(lt(reverse2,0), pn_stop , pp_stop )
step = T.switch(lt(reverse2,0), pn_step , pp_step )
else:
start = T.switch(lt(reverse2*reverse1,0),
T.switch(lt(reverse1,0), np_start, pn_start),
T.switch(lt(reverse1,0), nn_start, pp_start))
stop = T.switch(lt(reverse2*reverse1,0),
T.switch(lt(reverse1,0), np_stop , pn_stop ),
T.switch(lt(reverse1,0), nn_stop , pp_stop ))
step = T.switch(lt(reverse2*reverse1,0),
T.switch(lt(reverse1,0), np_step , pn_step ),
T.switch(lt(reverse1,0), nn_step , pp_step ))
return slice(start,stop,step)
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论