提交 652d0b69 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in local_dimshuffle_lift when elemwise has multiple outputs

上级 712660e7
......@@ -422,7 +422,12 @@ def local_dimshuffle_lift(fgraph, node):
inp = node.inputs[0]
inode = inp.owner
new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1):
if (
inode
and isinstance(inode.op, Elemwise)
and len(inode.outputs) == 1
and (len(fgraph.clients[inp]) == 1)
):
# Don't use make_node to have tag.test_value set.
new_inputs = []
for inp in inode.inputs:
......
......@@ -19,7 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import Composite
from pytensor.scalar.basic import Composite, float64
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs
......@@ -163,6 +163,20 @@ class TestDimshuffleLift:
# Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace")
def test_dimshuffle_lift_multi_out_elemwise(self):
# Create a multi-output Elemwise Op with Composite
x = float64("x")
outs = [x + 1, x + 2]
op = Elemwise(Composite([x], outs))
# Transpose both outputs
x = matrix("x")
outs = [out.T for out in op(x)]
# Make sure rewrite doesn't apply in this case
g = FunctionGraph([x], outs)
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
def test_local_useless_dimshuffle_in_reshape():
vec = TensorType(dtype="float64", shape=(None,))("vector")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论