提交 24102962 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Test new cases for local_subtensor_merge, and fix one of them.

上级 3944c371
...@@ -1152,11 +1152,13 @@ def local_subtensor_merge(node): ...@@ -1152,11 +1152,13 @@ def local_subtensor_merge(node):
1) var[int:][-1] -> var[-1] # a little different for when the first subtensor is empty. 1) var[int:][-1] -> var[-1] # a little different for when the first subtensor is empty.
2) var[::-1][int] -> var[-int-1] 2) var[::-1][int] -> var[-int-1]
3) var[::-1][:int] -> var[:-int-1:-1] 3) var[::-1][:int] -> var[:-int-1:-1]
4) var[int1::][:int2] -> var[int1:switch(idx1>=0, 4) var[int1::][:int2] ->
idx1, var[int1:int2 + switch(int2<0,
maximum(u.owner.inputs[0].shape[0]+idx1, 0) 0,
) + idx2] switch(int1>=0,
int1,
maximum(u.owner.inputs[0].shape[0]+int1,
0))]
""" """
if (isinstance(node.op, T.Subtensor) and if (isinstance(node.op, T.Subtensor) and
len(node.op.idx_list)==1): len(node.op.idx_list)==1):
...@@ -1270,11 +1272,17 @@ def local_subtensor_merge(node): ...@@ -1270,11 +1272,17 @@ def local_subtensor_merge(node):
elif isinstance(idx2, int): elif isinstance(idx2, int):
idx2 = T.as_tensor_variable(idx2) idx2 = T.as_tensor_variable(idx2)
# The maximum is needed to don't have shape[0] - idx1 < 0 # Get positive version of idx1
idx2_neg = T.maximum(u.owner.inputs[0].shape[0]+idx1, 0) # TODO: use Razvan's code for that
new_idx2 = T.switch(idx1>=0, idx1, idx2_neg)+idx2 # The maximum is needed so that shape[0] + idx1 >= 0
neg_idx1 = T.maximum(u.owner.inputs[0].shape[0]+idx1, 0)
new_idx1 = T.switch((idx1 >= 0), idx1, neg_idx1)
# If idx2<0, we are indexing from the end, so idx2 is OK
# If we are indexing from the beginning, we need to add pos_idx1
new_idx2 = idx2 + T.switch((idx2 < 0), 0, new_idx1)
return [u.owner.inputs[0][idx1:new_idx2]] return [u.owner.inputs[0][new_idx1:new_idx2]]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([None]) @gof.local_optimizer([None])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论