提交 8ba5f98d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

A better implementation of the subtensor subtensor merge optimization

上级 613a642a
......@@ -1232,50 +1232,39 @@ def merge_two_slices(slice1, len1, slice2, len2):
# according to the two steps we have 4 different combinations of
# positive/negative. I will denote the case I'm looking at by
# suffixes to the variables (nn,np,pn,pp):
pp_start = sl1.start + sl2.start
pp_stop = sl1.start + sl2.stop
pp_start = sl1.start + sl2.start * sl1.step
pp_stop = sl1.start + sl2.stop * sl1.step
pp_step = sl1.step * sl2.step
pn_start = sl1.start + sl2.start
pn_stop = sl1.start + sl2.stop
pn_stop = sl1.start + sl2.start * sl1.step
pn_start = sl1.start + sl2.stop * sl1.step
pn_step = sl1.step * sl2.step * -1
pn_stop = T.switch(T.eq(pn_stop,-1), -len1 -1, pn_stop)
np_start = sl1.stop - sl2.stop
np_stop = sl1.stop - sl2.start
np_stop = sl1.stop - sl2.stop * sl1.step -1
np_start = sl1.stop - sl2.start * sl1.step -1
np_step = sl1.step * sl2.step * -1
np_stop = T.switch(T.eq(np_stop,-1), -len1 -1, np_stop)
nn_start = sl1.stop - sl2.start
nn_stop = sl1.stop - sl2.stop
nn_start = sl1.stop - sl2.start * sl1.step
nn_stop = sl1.stop - sl2.stop * sl1.step
nn_step = sl1.step * sl2.step
start = const_fold(T.switch(T.lt(reverse2*reverse1,0),
T.switch(T.lt(reverse1,0), np_start, pn_start),
T.switch(T.lt(reverse1,0), nn_start,
pp_start)).owner)[0]
if reverse1 is None and reverse2 is None:
start = pp_start
stop = pp_stop
step = pp_step
elif reverse1 is not None and reverse2 is None:
start = T.switch(lt(reverse1,0), np_start, pp_start)
stop = T.switch(lt(reverse1,0), np_stop , pp_stop )
step = T.switch(lt(reverse1,0), np_step , pp_step )
elif reverse1 is None and reverse2 is not None:
start = T.switch(lt(reverse2,0), pn_start, pp_start)
stop = T.switch(lt(reverse2,0), pn_stop , pp_stop )
step = T.switch(lt(reverse2,0), pn_step , pp_step )
else:
start = T.switch(lt(reverse2*reverse1,0),
T.switch(lt(reverse1,0), np_start, pn_start),
T.switch(lt(reverse1,0), nn_start, pp_start))
stop = T.switch(lt(reverse2*reverse1,0),
T.switch(lt(reverse1,0), np_stop , pn_stop ),
T.switch(lt(reverse1,0), nn_stop , pp_stop ))
step = T.switch(lt(reverse2*reverse1,0),
T.switch(lt(reverse1,0), np_step , pn_step ),
T.switch(lt(reverse1,0), nn_step , pp_step ))
stop = const_fold(T.switch(T.lt(reverse2*reverse1,0),
T.switch(T.lt(reverse1,0), np_stop , pn_stop ),
T.switch(T.lt(reverse1,0), nn_stop , pp_stop
)).owner)[0]
return slice(start,stop,step)
step = const_fold( T.switch(T.lt(reverse2*reverse1,0),
T.switch(T.lt(reverse1,0), np_step , pn_step ),
T.switch(T.lt(reverse1,0), nn_step , pp_step
)).owner)[0]
return slice(start, stop, step)
@register_canonicalize
@register_specialize
......@@ -1314,8 +1303,8 @@ def local_subtensor_merge(node):
if type(slice1) is slice:
merged_slices.append(
merge_two_slices(slice1,
slices2[pos_2],
xshape[pos_1],
slices2[pos_2],
ushape[pos_2]))
pos_2 += 1
else:
......@@ -1326,7 +1315,9 @@ def local_subtensor_merge(node):
sl_ins = T.Subtensor.collapse(
merged_slices,
lambda x: isinstance(x, T.Variable))
return [ subtens.make_node(node.inputs[0], *sl_ins).outputs[0]]
out = subtens.make_node(x, *sl_ins).outputs[0]
return [ out ]
@register_canonicalize
@gof.local_optimizer([None])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论