PatternOptimizer can now take ResultBase instances marked as constant in its patterns

上级 b2b0d215
......@@ -114,6 +114,18 @@ class _test_PatternOptimizer(unittest.TestCase):
'1').optimize(g)
assert str(g) == "[Op1(x)]"
def test_6(self):
x, y, z = inputs()
x.constant = True
x.value = 2
z.constant = True
z.value = 2
e = op1(op1(x, y), y)
g = env([y], [e])
PatternOptimizer((Op1, z, '1'),
(Op2, '1', z)).optimize(g)
assert str(g) == "[Op1(Op2(y, z), y)]"
class _test_OpSubOptimizer(unittest.TestCase):
......
from op import Op
from result import ResultBase
from env import InconsistencyError
import utils
import unify
......@@ -140,10 +141,14 @@ class PatternOptimizer(OpSpecificOptimizer):
return False
else:
u = u.merge(expr, v)
else:
if pattern != expr:
return False
elif isinstance(pattern, ResultBase) \
and getattr(pattern, 'constant', False) \
and isinstance(expr, ResultBase) \
and getattr(expr, 'constant', False) \
and pattern.hash() == expr.hash():
return u
else:
return False
return u
def build(pattern, u):
......@@ -204,8 +209,12 @@ class MergeOptimizer(Optimizer):
for i, r in enumerate(env.orphans().union(env.inputs)):
if getattr(r, 'constant', False) and hasattr(r, 'hash'):
ref = ('const', r.hash())
cid[r] = ref
inv_cid[ref] = r
other_r = inv_cid.get(ref, None)
if other_r is not None:
env.replace(r, other_r)
else:
cid[r] = ref
inv_cid[ref] = r
else:
cid[r] = i
inv_cid[i] = r
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论