made an Optimizer to mark undestroyed orphans constant

上级 d7d8792a
...@@ -79,6 +79,15 @@ class BaseTensor(ResultBase): ...@@ -79,6 +79,15 @@ class BaseTensor(ResultBase):
#TODO: add more type correspondances for e.g. int32, int64, float32, #TODO: add more type correspondances for e.g. int32, int64, float32,
#complex64, etc. #complex64, etc.
return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype] return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype]
#
# Hash for constant folding
#
def hash(self):
if self.data:
return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:])
else:
return (BaseTensor, self.dtype, self.broadcastable, None)
# #
# C codegen stubs # C codegen stubs
......
...@@ -20,6 +20,9 @@ class MyResult(ResultBase): ...@@ -20,6 +20,9 @@ class MyResult(ResultBase):
def __repr__(self): def __repr__(self):
return self.name return self.name
def hash(self):
return self.data
class MyOp(Op): class MyOp(Op):
...@@ -129,6 +132,39 @@ class _test_OpSubOptimizer(unittest.TestCase): ...@@ -129,6 +132,39 @@ class _test_OpSubOptimizer(unittest.TestCase):
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]" assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
class _test_MergeOptimizer(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
def test_1(self):
x, y, z = inputs()
y.data = 2
y.constant = True
z.data = 2.0
z.constant = True
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_2(self):
x, y, z = inputs()
y.data = 2
z.data = 2
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x], [e])
ConstantFinder().optimize(g)
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
from features import Listener, Constraint, Orderings from features import Listener, Constraint, Orderings, Tool
from utils import AbstractFunctionError from utils import AbstractFunctionError
from copy import copy from copy import copy
...@@ -13,7 +13,7 @@ __all__ = ['Destroyer', ...@@ -13,7 +13,7 @@ __all__ = ['Destroyer',
class DestroyHandler(Listener, Constraint, Orderings): class DestroyHandler(Listener, Constraint, Orderings, Tool):
def __init__(self, env): def __init__(self, env):
self.parent = {} self.parent = {}
...@@ -28,6 +28,13 @@ class DestroyHandler(Listener, Constraint, Orderings): ...@@ -28,6 +28,13 @@ class DestroyHandler(Listener, Constraint, Orderings):
for input in env.inputs: for input in env.inputs:
self.children[input] = set() self.children[input] = set()
def publish(self):
def __destroyers():
ret = self.destroyers.get(foundation, set())
ret = ret.keys()
return ret
self.env.destroyers = __destroyers
def __path__(self, r): def __path__(self, r):
path = self.paths.get(r, None) path = self.paths.get(r, None)
if path: if path:
......
...@@ -83,8 +83,6 @@ class OpSubOptimizer(Optimizer): ...@@ -83,8 +83,6 @@ class OpSubOptimizer(Optimizer):
try: try:
# note: only replaces the default 'out' port if it exists # note: only replaces the default 'out' port if it exists
r = self.op2(*op.inputs).out r = self.op2(*op.inputs).out
# if isinstance(r, Op):
# r = r.out
env.replace(op.out, r) env.replace(op.out, r)
except InconsistencyError, e: except InconsistencyError, e:
# print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug # print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug
...@@ -179,15 +177,38 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -179,15 +177,38 @@ class PatternOptimizer(OpSpecificOptimizer):
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern)) return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
class ConstantFinder(Optimizer):
def apply(self, env):
if env.has_feature(ext.DestroyHandler):
for r in env.orphans():
if not env.destroyers(r):
r.indestructible = True
r.constant = True
for r in env.inputs:
if not env.destroyers(r):
r.indestructible = True
else:
for r in env.orphans():
r.indestructible = True
r.constant = True
for r in env.inputs:
r.indestructible = True
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
def apply(self, env): def apply(self, env):
cid = {} cid = {}
inv_cid = {} inv_cid = {}
for i, r in enumerate(env.inputs.union(env.orphans())): for i, r in enumerate(env.orphans().union(env.inputs)):
cid[r] = i if getattr(r, 'constant', False) and hasattr(r, 'hash'):
inv_cid[i] = r ref = ('const', r.hash())
cid[r] = ref
inv_cid[ref] = r
else:
cid[r] = i
inv_cid[i] = r
for op in env.io_toposort(): for op in env.io_toposort():
op_cid = (op.__class__, tuple([cid[input] for input in op.inputs])) op_cid = (op.__class__, tuple([cid[input] for input in op.inputs]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论