提交 fc21336a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow fill_sink rewrite to accomodate changes in broadcastability

上级 a6255d69
...@@ -351,10 +351,7 @@ def local_fill_sink(fgraph, node): ...@@ -351,10 +351,7 @@ def local_fill_sink(fgraph, node):
# Check if we need to propagate the fill to the new outputs # Check if we need to propagate the fill to the new outputs
# It's enough to check the first output, as Elemwise outputs must all have the same shapes # It's enough to check the first output, as Elemwise outputs must all have the same shapes
# Note: There are orderings that may require fewer fills. # Note: There are orderings that may require fewer fills.
old_bcast_pattern = node.outputs[0].type.broadcastable for model in models:
models_iter = iter(models)
while old_bcast_pattern != outputs[0].type.broadcastable:
model = next(models_iter)
# Only apply this model if it would actually do anything # Only apply this model if it would actually do anything
if broadcasted_by(outputs[0], model): if broadcasted_by(outputs[0], model):
outputs = [fill(model, output) for output in outputs] outputs = [fill(model, output) for output in outputs]
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import pytensor import pytensor
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import shared from pytensor import graph_replace, shared
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.mode import get_default_mode, get_mode
...@@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client(): ...@@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client():
[new_out] = fg.outputs [new_out] = fg.outputs
expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x)))) expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
assert equal_computations([new_out], [expected_out]) assert equal_computations([new_out], [expected_out])
def test_topological_fill_sink_broadcastable_change():
"""Test rewrite doesn't fail after a graph replacement that provides a broadcastable change."""
a = vector("a", shape=(1,))
b = vector("b", shape=(1,))
zeros = pt.vector("zeros", shape=(None,))
initial_out = pt.full_like(zeros, a) + b
# Make broadcast to zeros irrelevant
out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False)
fg = FunctionGraph([a, b], [out], copy_inputs=False)
topological_fill_sink.rewrite(fg)
[new_out] = fg.outputs
assert equal_computations([new_out], [a + b])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论