Commit 0b558d8e, authored by Ricardo Vieira, committed by Ricardo Vieira

Retain more precise types in MergeOptimizer

This can avoid some infinite rewrite loops where a SpecifyShape is lifted, removed, and then reintroduced at the bottom by the MergeOptimizer.
Parent commit: 4cc13bce
...@@ -743,14 +743,22 @@ class MergeOptimizer(GraphRewriter): ...@@ -743,14 +743,22 @@ class MergeOptimizer(GraphRewriter):
): ):
continue continue
if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type: # Keep the variable with the most specific static type from the pairs
res = pairs[0][0].type.convert_variable(pairs[0][1]) # E.g the second in (TensorType(shape=(None,), TensorType(shape=(5,))
# Otherwise we could end up reverting type inference progress done elsewhere.
# Since the fgraph.replace only checks the convert_variable for pair_idx in range(len(pairs)):
# in one way, we change the order in the case that old, new = pairs[pair_idx]
# convert_variable will not be successful. if old.type == new.type:
if not res: continue
pairs = [(pairs[0][1], pairs[0][0])] # Check if type of new replacement is at least as specific as that of the old variable
if not old.type.is_super(new.type):
# Check the other way around
if new.type.is_super(old.type):
pairs[pair_idx] = (new, old)
else:
# Replacement requires some operation like specify_shape
new_repl = old.type.convert_variable(new)
pairs[pair_idx] = (old, new_repl)
try: try:
# If they're all `AtomicVariable`s, there's no need to call validate. # If they're all `AtomicVariable`s, there's no need to call validate.
......
...@@ -21,10 +21,10 @@ from pytensor.graph.rewriting.basic import ( ...@@ -21,10 +21,10 @@ from pytensor.graph.rewriting.basic import (
pre_greedy_node_rewriter, pre_greedy_node_rewriter,
) )
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import ( from tests.graph.utils import (
MyOp, MyOp,
...@@ -441,6 +441,23 @@ class TestMergeOptimizer: ...@@ -441,6 +441,23 @@ class TestMergeOptimizer:
assert fg.outputs[0] is fg.outputs[1] assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2] assert fg.outputs[0] is not fg.outputs[2]
@pytest.mark.parametrize("reverse", [False, True])
def test_merge_more_specific_types(self, reverse):
"""Check that we choose the most specific static type when merging variables."""
x1 = vector("x1", shape=(None,))
x2 = vector("x2", shape=(500,))
y1 = exp(x1)
y2 = exp(x2)
# Simulate case where we find that x2 is equivalent to x1
fg = FunctionGraph([x1, x2], [y2, y1] if reverse else [y1, y2], clone=False)
fg.replace(x1, x2)
MergeOptimizer().rewrite(fg)
assert fg.outputs == [y2, y2]
class TestEquilibrium: class TestEquilibrium:
def test_1(self): def test_1(self):
......
Markdown formatting is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment.