提交 8e1cd563 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid no-op DimShuffle

上级 f72d7e58
...@@ -349,6 +349,9 @@ class _tensor_py_operators: ...@@ -349,6 +349,9 @@ class _tensor_py_operators:
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)): if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0] pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern) ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
if ds_op.new_order == tuple(range(self.type.ndim)):
# No-op
return self
return ds_op(self) return ds_op(self)
def flatten(self, ndim=1): def flatten(self, ndim=1):
......
...@@ -950,7 +950,7 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -950,7 +950,7 @@ def test_Dimshuffle_lift_restrictions():
1e-7, 1e-7,
), ),
( (
(0, 1, 2), (0, 2, 1),
True, True,
normal, normal,
(np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)), (np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
......
...@@ -148,7 +148,7 @@ class TestDimshuffleLift: ...@@ -148,7 +148,7 @@ class TestDimshuffleLift:
def test_useless_dimshuffle(self): def test_useless_dimshuffle(self):
x, *_ = inputs() x, *_ = inputs()
e = ds(x, (0, 1)) e = DimShuffle(new_order=(0, 1), input_ndim=2)(x)
g = FunctionGraph([x], [e], clone=False) g = FunctionGraph([x], [e], clone=False)
assert isinstance(g.outputs[0].owner.op, DimShuffle) assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论