提交 4a6a10b1 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Fix bug in `PatternSub` when `get_nodes` and `values_eq_approx` are specified

上级 15c64971
...@@ -1692,9 +1692,6 @@ class PatternSub(LocalOptimizer): ...@@ -1692,9 +1692,6 @@ class PatternSub(LocalOptimizer):
continue continue
ret = self.transform(fgraph, real_node, get_nodes=False) ret = self.transform(fgraph, real_node, get_nodes=False)
if ret is not False and ret is not None: if ret is not False and ret is not None:
assert len(real_node.outputs) == len(ret)
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
return dict(zip(real_node.outputs, ret)) return dict(zip(real_node.outputs, ret))
if node.op != self.op: if node.op != self.op:
...@@ -1710,7 +1707,6 @@ class PatternSub(LocalOptimizer): ...@@ -1710,7 +1707,6 @@ class PatternSub(LocalOptimizer):
if expr_equiv is None: if expr_equiv is None:
return False return False
# TODO: Not sure how to handle multiple_clients flag # TODO: Not sure how to handle multiple_clients flag
# print 'retrying match', pattern, expr_equiv
return match( return match(
pattern, pattern,
expr_equiv, expr_equiv,
...@@ -1774,26 +1770,30 @@ class PatternSub(LocalOptimizer): ...@@ -1774,26 +1770,30 @@ class PatternSub(LocalOptimizer):
return u return u
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb) u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb)
if u: if not u:
return False
def build(pattern, u):
if isinstance(pattern, (list, tuple)): def build(pattern, u):
args = [build(p, u) for p in pattern[1:]] if isinstance(pattern, (list, tuple)):
return pattern[0](*args) args = [build(p, u) for p in pattern[1:]]
elif isinstance(pattern, str): return pattern[0](*args)
return u[unify.Var(pattern)] elif isinstance(pattern, str):
elif isinstance(pattern, (int, float)): return u[unify.Var(pattern)]
return pattern elif isinstance(pattern, (int, float)):
else: return pattern
return pattern.clone() else:
return pattern.clone()
p = self.out_pattern ret = build(self.out_pattern, u)
ret = build(p, u)
if self.values_eq_approx: if isinstance(ret, (int, float)):
ret.tag.values_eq_approx = self.values_eq_approx # TODO: Should we convert these to constants explicitly?
return [ret] return [ret]
else:
return False if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
return [ret]
def __str__(self): def __str__(self):
if getattr(self, "__name__", None): if getattr(self, "__name__", None):
......
...@@ -20,7 +20,7 @@ from aesara.graph.opt import ( ...@@ -20,7 +20,7 @@ from aesara.graph.opt import (
from aesara.tensor.basic_opt import constant_folding from aesara.tensor.basic_opt import constant_folding
from aesara.tensor.math import dot from aesara.tensor.math import dot
from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import matrix from aesara.tensor.type import matrix, values_eq_approx_always_true
from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import ( from tests.graph.utils import (
MyType, MyType,
...@@ -644,3 +644,36 @@ def test_pre_greedy_local_optimizer(): ...@@ -644,3 +644,36 @@ def test_pre_greedy_local_optimizer():
# Make sure constant of slice signature is hashable. # Make sure constant of slice signature is hashable.
assert isinstance(hash(cst.signature()), int) assert isinstance(hash(cst.signature()), int)
@pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
def test_patternsub_values_eq_approx(out_pattern, tracks):
# PatternSub would fail when `values_eq_approx` and `get_nodes` were specified
x = MyVariable("x")
e = op1(x)
fg = FunctionGraph([x], [e], clone=False)
opt = EquilibriumOptimizer(
[
PatternSub(
(op1, "x"),
out_pattern,
tracks=[op1] if tracks else (),
get_nodes=(lambda fgraph, node: [node]) if tracks else None,
values_eq_approx=values_eq_approx_always_true,
)
],
max_use_ratio=1,
)
opt.optimize(fg)
output = fg.outputs[0]
if isinstance(out_pattern, tuple):
assert output.owner.op == op2
assert output.tag.values_eq_approx is values_eq_approx_always_true
elif out_pattern == "x":
assert output is x
assert output.tag.values_eq_approx is values_eq_approx_always_true
else:
assert isinstance(output, Constant)
assert not hasattr(output.tag, "value_eq_approx")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论