提交 835d8444 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

special case when ndim is 0 is now handled

上级 fc2b5e90
......@@ -609,7 +609,7 @@ def local_dimshuffle_lift(node):
# covers two types of useless dimshuffle:
# 1 - dimshuffle all dimensions in order
# 2 - dimshuffle a broadcastable dimension
is_useless = True
is_useless = False
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
......@@ -617,7 +617,7 @@ def local_dimshuffle_lift(node):
if (op.new_order[i] == i or
(i in all_broadcastable_dims and
op.new_order[i] in all_broadcastable_dims)):
continue
is_useless = True
else:
is_useless = False
break
......
......@@ -204,13 +204,15 @@ class test_dimshuffle_lift(unittest.TestCase):
def test_dimshuffle_on_broadcastable(self):
x, y, z = inputs([False, True], [True, False, True], [False, False, True])
u = tensor.constant(1)
ds_x = ds(x, (0, 'x')) # useless
ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # usefull
g = FunctionGraph([x, y, z], [ds_x, ds_y, ds_z])
self.assertTrue(str(g) == "[DimShuffle{0,x}(x), DimShuffle{2,1,0}(y), DimShuffle{2,1,0}(z)]")
ds_u = ds(u, ('x')) # usefull
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
self.assertTrue(str(g) == "[DimShuffle{0,x}(x), DimShuffle{2,1,0}(y), DimShuffle{2,1,0}(z), DimShuffle{x}(TensorConstant{1})]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x, y, DimShuffle{2,1,0}(z)]")
self.assertTrue(str(g) == "[x, y, DimShuffle{2,1,0}(z), DimShuffle{x}(TensorConstant{1})]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论