提交 652d0b69 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in local_dimshuffle_lift when elemwise has multiple outputs

上级 712660e7
...@@ -422,7 +422,12 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -422,7 +422,12 @@ def local_dimshuffle_lift(fgraph, node):
inp = node.inputs[0] inp = node.inputs[0]
inode = inp.owner inode = inp.owner
new_order = op.new_order new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1): if (
inode
and isinstance(inode.op, Elemwise)
and len(inode.outputs) == 1
and (len(fgraph.clients[inp]) == 1)
):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
new_inputs = [] new_inputs = []
for inp in inode.inputs: for inp in inode.inputs:
......
...@@ -19,7 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -19,7 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite, float64
from pytensor.tensor.basic import MakeVector from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs from pytensor.tensor.math import abs as at_abs
...@@ -163,6 +163,20 @@ class TestDimshuffleLift: ...@@ -163,6 +163,20 @@ class TestDimshuffleLift:
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace") assert hasattr(g.outputs[0].tag, "trace")
def test_dimshuffle_lift_multi_out_elemwise(self):
# Create a multi-output Elemwise Op with Composite
x = float64("x")
outs = [x + 1, x + 2]
op = Elemwise(Composite([x], outs))
# Transpose both outputs
x = matrix("x")
outs = [out.T for out in op(x)]
# Make sure rewrite doesn't apply in this case
g = FunctionGraph([x], outs)
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
def test_local_useless_dimshuffle_in_reshape(): def test_local_useless_dimshuffle_in_reshape():
vec = TensorType(dtype="float64", shape=(None,))("vector") vec = TensorType(dtype="float64", shape=(None,))("vector")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论