提交 fc2b5e90 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

dimshuffle on a broadcastable dimension is handled

上级 9df89a90
...@@ -606,8 +606,22 @@ def local_dimshuffle_lift(node): ...@@ -606,8 +606,22 @@ def local_dimshuffle_lift(node):
return [ret] return [ret]
# remove useless dimshuffle in general # remove useless dimshuffle in general
if (list(op.new_order) == list(range(len(op.new_order))) and # covers two types of useless dimshuffle:
len(op.new_order) == input.type.ndim): # 1 - dimshuffle all dimensions in order
# 2 - dimshuffle a broadcastable dimension
is_useless = True
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (op.new_order[i] == i or
(i in all_broadcastable_dims and
op.new_order[i] in all_broadcastable_dims)):
continue
else:
is_useless = False
break
if is_useless:
return [input] return [input]
......
...@@ -202,6 +202,18 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -202,6 +202,18 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_dimshuffle_on_broadcastable(self):
x, y, z = inputs([False, True], [True, False, True], [False, False, True])
ds_x = ds(x, (0, 'x')) # useless
ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # usefull
g = FunctionGraph([x, y, z], [ds_x, ds_y, ds_z])
self.assertTrue(str(g) == "[DimShuffle{0,x}(x), DimShuffle{2,1,0}(y), DimShuffle{2,1,0}(z)]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x, y, DimShuffle{2,1,0}(z)]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
n_segments = 10 n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论