提交 0406b6fe authored 作者: Razvan Pascanu's avatar Razvan Pascanu

making the subtensor merge optimization work

上级 2c9fa876
...@@ -1802,7 +1802,9 @@ def local_subtensor_merge(node): ...@@ -1802,7 +1802,9 @@ def local_subtensor_merge(node):
merged_slices = [] merged_slices = []
pos_2 = 0 pos_2 = 0
for pos_1, slice1 in enumerate(slices1): pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if type(slice1) is slice: if type(slice1) is slice:
merged_slices.append( merged_slices.append(
merge_two_slices(slice1, merge_two_slices(slice1,
...@@ -1812,8 +1814,14 @@ def local_subtensor_merge(node): ...@@ -1812,8 +1814,14 @@ def local_subtensor_merge(node):
pos_2 += 1 pos_2 += 1
else: else:
merged_slices.append(slice1) merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
merged_slices += slices2[pos_2:]
subtens = T.Subtensor(merged_slices) subtens = T.Subtensor(merged_slices)
sl_ins = T.Subtensor.collapse( sl_ins = T.Subtensor.collapse(
merged_slices, merged_slices,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论