Commit 8c157a25, authored by Ricardo Vieira, committed by Ricardo Vieira

Fix local_fill_sink rewrite for multiple output Elemwise Ops

The changes get rid of the eager sink at the local node rewriter level. This was actually not working because the nested replacements referenced variables that were never part of the original fgraph and those replacements were being ignored altogether. Instead we wrap the rewrite in an in2out that will safely achieve the intended behavior.
Parent: d80c0bf7
......@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......@@ -320,56 +321,52 @@ def local_elemwise_alloc(fgraph, node):
return new_outs
@node_rewriter([Elemwise])
def local_fill_sink(fgraph, node):
    """Sink ``fill`` through an Elemwise node.

    ``f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))``

    ``f`` must be an Elemwise that isn't itself a ``fill``.

    The node-level rewrite only performs a single local substitution; it is
    wrapped below in an ``in2out`` GraphRewriter (``topological_fill_sink``)
    so that fills are sunk all the way to the terminal nodes in one pass.
    The previous implementation eagerly walked ``fgraph.clients`` and built
    nested replacements, but those referenced variables that were never part
    of the original fgraph, so they were silently ignored.
    """
    # `fill` is Elemwise(Second); never sink through another fill.
    if isinstance(node.op.scalar_op, Second):
        return False

    models = []
    inputs = []
    for inp in node.inputs:
        if inp.owner and inp.owner.op == fill:
            a, b = inp.owner.inputs
            if b.type.dtype != inp.dtype:
                # The input was implicitly casted by the fill operation
                b = b.cast(inp.dtype)
            models.append(a)
            inputs.append(b)
        else:
            inputs.append(inp)

    if not models:
        # No fill among the inputs: nothing to sink.
        return False

    outputs = node.op.make_node(*inputs).outputs

    # Check if we need to propagate the fill to the new outputs.
    # It's enough to check the first output, as Elemwise outputs must all
    # have the same shapes.
    # Note: There are orderings that may require fewer fills.
    old_bcast_pattern = node.outputs[0].type.broadcastable
    models_iter = iter(models)
    while old_bcast_pattern != outputs[0].type.broadcastable:
        model = next(models_iter)
        # Only apply this model if it would actually do anything
        if broadcasted_by(outputs[0], model):
            outputs = [fill(model, output) for output in outputs]

    return outputs


# The rewrite is wrapped in an in2out GraphRewriter
# so that fill can be sinked until the terminal nodes in a single pass
# through the graph, without triggering other rewrites after each local
# substitution.
topological_fill_sink = in2out(local_fill_sink)
register_canonicalize(topological_fill_sink, "shape_unsafe")
@register_specialize("shape_unsafe")
......
......@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint, pprint
from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.scalar import Composite, float64
from pytensor.tensor.basic import (
Alloc,
Join,
......@@ -64,6 +65,7 @@ from pytensor.tensor.rewriting.basic import (
local_merge_alloc,
local_useless_alloc,
local_useless_elemwise,
topological_fill_sink,
)
from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
from pytensor.tensor.rewriting.shape import ShapeFeature
......@@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag():
fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
with pytest.raises(ValueError):
fn([0, 1], [2, 3, 4]), [0, 1]
def test_topological_fill_sink_multi_output_client():
    """Regression test: fills must sink correctly through a multi-output Elemwise.

    Builds an Elemwise wrapping a Composite scalar op with two outputs and
    checks that ``topological_fill_sink`` pushes the broadcasting ``fill``
    past it, so the final graph broadcasts once at the end instead of at
    the input.
    """
    x = float64("x")
    elem_op_with_2_outputs = Elemwise(Composite([x], [x + 1, x + 2]))

    x = pt.vector("x", shape=(1,))
    z = pt.vector("z", shape=(None,))
    bcast_x = pt.full_like(z, x)
    out = pt.add(*elem_op_with_2_outputs(pt.exp(bcast_x)))

    fg = FunctionGraph([x, z], [out], copy_inputs=False)
    topological_fill_sink.rewrite(fg)
    [new_out] = fg.outputs
    expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
    assert equal_computations([new_out], [expected_out])
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment