提交 8e1cd563 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid no-op DimShuffle

上级 f72d7e58
......@@ -349,6 +349,9 @@ class _tensor_py_operators:
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
if ds_op.new_order == tuple(range(self.type.ndim)):
# No-op
return self
return ds_op(self)
def flatten(self, ndim=1):
......
......@@ -950,7 +950,7 @@ def test_Dimshuffle_lift_restrictions():
1e-7,
),
(
(0, 1, 2),
(0, 2, 1),
True,
normal,
(np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
......
......@@ -148,7 +148,7 @@ class TestDimshuffleLift:
def test_useless_dimshuffle(self):
x, *_ = inputs()
e = ds(x, (0, 1))
e = DimShuffle(new_order=(0, 1), input_ndim=2)(x)
g = FunctionGraph([x], [e], clone=False)
assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论