提交 0b558d8e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Retain more precise types in MergeOptimizer

This can avoid some infinite rewrite loops where a SpecifyShape is lifted, removed and then reintroduced at the bottom by the MergeOptimizer
上级 4cc13bce
......@@ -743,14 +743,22 @@ class MergeOptimizer(GraphRewriter):
):
continue
if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
res = pairs[0][0].type.convert_variable(pairs[0][1])
# Since the fgraph.replace only checks the convert_variable
# in one way, we change the order in the case that
# convert_variable will not be successful.
if not res:
pairs = [(pairs[0][1], pairs[0][0])]
# Keep the variable with the most specific static type from the pairs
# E.g the second in (TensorType(shape=(None,), TensorType(shape=(5,))
# Otherwise we could end up reverting type inference progress done elsewhere.
for pair_idx in range(len(pairs)):
old, new = pairs[pair_idx]
if old.type == new.type:
continue
# Check if type of new replacement is at least as specific as that of the old variable
if not old.type.is_super(new.type):
# Check the other way around
if new.type.is_super(old.type):
pairs[pair_idx] = (new, old)
else:
# Replacement requires some operation like specify_shape
new_repl = old.type.convert_variable(new)
pairs[pair_idx] = (old, new_repl)
try:
# If they're all `AtomicVariable`s, there's no need to call validate.
......
......@@ -21,10 +21,10 @@ from pytensor.graph.rewriting.basic import (
pre_greedy_node_rewriter,
)
from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot
from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import (
MyOp,
......@@ -441,6 +441,23 @@ class TestMergeOptimizer:
assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2]
@pytest.mark.parametrize("reverse", [False, True])
def test_merge_more_specific_types(self, reverse):
"""Check that we choose the most specific static type when merging variables."""
x1 = vector("x1", shape=(None,))
x2 = vector("x2", shape=(500,))
y1 = exp(x1)
y2 = exp(x2)
# Simulate case where we find that x2 is equivalent to x1
fg = FunctionGraph([x1, x2], [y2, y1] if reverse else [y1, y2], clone=False)
fg.replace(x1, x2)
MergeOptimizer().rewrite(fg)
assert fg.outputs == [y2, y2]
class TestEquilibrium:
def test_1(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论