提交 0406b6fe authored 作者: Razvan Pascanu's avatar Razvan Pascanu

making the subtensor merge optimization work

上级 2c9fa876
......@@ -1802,7 +1802,9 @@ def local_subtensor_merge(node):
merged_slices = []
pos_2 = 0
for pos_1, slice1 in enumerate(slices1):
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if type(slice1) is slice:
merged_slices.append(
merge_two_slices(slice1,
......@@ -1812,8 +1814,14 @@ def local_subtensor_merge(node):
pos_2 += 1
else:
merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
merged_slices += slices2[pos_2:]
subtens = T.Subtensor(merged_slices)
sl_ins = T.Subtensor.collapse(
merged_slices,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论