提交 623b4175 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed small bug in destroying orphans

上级 0f344ee4
...@@ -3,6 +3,7 @@ import unittest ...@@ -3,6 +3,7 @@ import unittest
from result import ResultBase from result import ResultBase
from op import Op from op import Op
from ext import Destroyer
from opt import * from opt import *
from env import Env from env import Env
from toolbox import * from toolbox import *
...@@ -47,6 +48,10 @@ class Op3(MyOp): ...@@ -47,6 +48,10 @@ class Op3(MyOp):
class Op4(MyOp): class Op4(MyOp):
pass pass
class OpD(MyOp, Destroyer):
def destroyed_inputs(self):
return [self.inputs[0]]
import modes import modes
modes.make_constructors(globals()) modes.make_constructors(globals())
...@@ -175,6 +180,7 @@ class _test_ConstantFinder(unittest.TestCase): ...@@ -175,6 +180,7 @@ class _test_ConstantFinder(unittest.TestCase):
e = op1(x, y, z) e = op1(x, y, z)
g = env([x], [e]) g = env([x], [e])
ConstantFinder().optimize(g) ConstantFinder().optimize(g)
assert y.constant and z.constant
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(x, y, y)]" \ assert str(g) == "[Op1(x, y, y)]" \
or str(g) == "[Op1(x, z, z)]" or str(g) == "[Op1(x, z, z)]"
...@@ -186,10 +192,22 @@ class _test_ConstantFinder(unittest.TestCase): ...@@ -186,10 +192,22 @@ class _test_ConstantFinder(unittest.TestCase):
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x], [e]) g = env([x], [e])
ConstantFinder().optimize(g) ConstantFinder().optimize(g)
assert y.constant and z.constant
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]" or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_2(self):
x, y, z = inputs()
y.data = 2
z.data = 2
e = op_d(x, op2(y, z))
g = env([y], [e])
ConstantFinder().optimize(g)
assert not getattr(x, 'constant', False) and z.constant
MergeOptimizer().optimize(g)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -25,15 +25,16 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -25,15 +25,16 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.illegal = set() self.illegal = set()
self.env = env self.env = env
self.seen = set() self.seen = set()
for input in env.inputs: for input in env.orphans().union(env.inputs):
self.children[input] = set() self.children[input] = set()
def publish(self): def publish(self):
def __destroyers(): def __destroyers(r):
ret = self.destroyers.get(foundation, set()) ret = self.destroyers.get(r, {})
ret = ret.keys() ret = ret.keys()
return ret return ret
self.env.destroyers = __destroyers self.env.destroyers = __destroyers
self.env.destroy_handler = self
def __path__(self, r): def __path__(self, r):
path = self.paths.get(r, None) path = self.paths.get(r, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论