Commit 8c157a25 authored by Ricardo Vieira, committed by Ricardo Vieira

Fix local_fill_sink rewrite for multiple output Elemwise Ops

The changes get rid of the eager sinking at the local node rewriter level. That approach was not actually working: the nested replacements referenced variables that were never part of the original fgraph, so those replacements were ignored altogether. Instead, we wrap the rewrite in an in2out graph rewriter that safely achieves the intended behavior.
Parent d80c0bf7
...@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import RewriteDatabase from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
...@@ -320,56 +321,52 @@ def local_elemwise_alloc(fgraph, node): ...@@ -320,56 +321,52 @@ def local_elemwise_alloc(fgraph, node):
return new_outs return new_outs
@node_rewriter([Elemwise])
def local_fill_sink(fgraph, node):
    """Sink ``fill`` through an Elemwise: f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))).

    ``f`` must be an Elemwise that is not itself a ``fill`` (i.e. not a
    ``Second`` scalar op). Returns the new outputs, or ``False`` when there
    is nothing to rewrite.
    """
    # fill is Elemwise(Second); sinking a fill through another fill is pointless.
    if isinstance(node.op.scalar_op, Second):
        return False

    # Collect the "model" (shape-donor) inputs of every fill, and replace
    # each fill by its value input so f is applied to the unfilled values.
    models = []
    inputs = []
    for inp in node.inputs:
        if inp.owner and inp.owner.op == fill:
            shape_model, value = inp.owner.inputs
            if value.type.dtype != inp.dtype:
                # The input was implicitly casted by the fill operation
                value = value.cast(inp.dtype)
            models.append(shape_model)
            inputs.append(value)
        else:
            inputs.append(inp)

    if not models:
        # No fill among the inputs; nothing to sink.
        return False

    outputs = node.op.make_node(*inputs).outputs

    # Check if we need to propagate the fill to the new outputs
    # It's enough to check the first output, as Elemwise outputs must all have the same shapes
    # Note: There are orderings that may require fewer fills.
    old_bcast_pattern = node.outputs[0].type.broadcastable
    models_iter = iter(models)
    while old_bcast_pattern != outputs[0].type.broadcastable:
        model = next(models_iter)
        # Only apply this model if it would actually do anything
        if broadcasted_by(outputs[0], model):
            outputs = [fill(model, output) for output in outputs]

    return outputs


# The rewrite is wrapped in an in2out GraphRewriter so that fill can be sunk
# all the way to the terminal nodes in a single pass through the graph,
# without triggering other rewrites after each local substitution.
topological_fill_sink = in2out(local_fill_sink)
register_canonicalize(topological_fill_sink, "shape_unsafe")
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
......
...@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint, pprint from pytensor.printing import debugprint, pprint
from pytensor.raise_op import Assert, CheckAndRaise from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.scalar import Composite, float64
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
Join, Join,
...@@ -64,6 +65,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -64,6 +65,7 @@ from pytensor.tensor.rewriting.basic import (
local_merge_alloc, local_merge_alloc,
local_useless_alloc, local_useless_alloc,
local_useless_elemwise, local_useless_elemwise,
topological_fill_sink,
) )
from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
...@@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag(): ...@@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag():
fn = function([x, y], out, mode=mode.excluding("shape_unsafe")) fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
with pytest.raises(ValueError): with pytest.raises(ValueError):
fn([0, 1], [2, 3, 4]), [0, 1] fn([0, 1], [2, 3, 4]), [0, 1]
def test_topological_fill_sink_multi_output_client():
    """Fill sinking must handle Elemwise clients with multiple outputs."""
    # A scalar Composite with two outputs, wrapped as a multi-output Elemwise.
    xs = float64("x")
    elem_op_with_2_outputs = Elemwise(Composite([xs], [xs + 1, xs + 2]))

    x = pt.vector("x", shape=(1,))
    z = pt.vector("z", shape=(None,))
    # full_like introduces the fill that should be sunk past exp and the
    # two-output Elemwise, down to the final add.
    bcast_x = pt.full_like(z, x)
    out = pt.add(*elem_op_with_2_outputs(pt.exp(bcast_x)))

    fg = FunctionGraph([x, z], [out], copy_inputs=False)
    topological_fill_sink.rewrite(fg)
    [new_out] = fg.outputs

    expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
    assert equal_computations([new_out], [expected_out])
Markdown is supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment