require -> require_set, __require__ -> __env_require__

上级 a78dff92
...@@ -53,7 +53,22 @@ class InconsistencyError(GofError): ...@@ -53,7 +53,22 @@ class InconsistencyError(GofError):
""" """
pass pass
def require_set(cls):
"""Return the set of objects named in a __env_require__ field in a base class"""
r = set()
if hasattr(cls, '__class__'):
cls = cls.__class__
bases = utils.all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
for base in bases:
req = base.__env_require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
class Env(graph.Graph): class Env(graph.Graph):
""" """
...@@ -179,7 +194,7 @@ class Env(graph.Graph): ...@@ -179,7 +194,7 @@ class Env(graph.Graph):
return True return True
def satisfy(self, x): def satisfy(self, x):
for feature_class in x.require(): for feature_class in require_set(x):
self.add_feature(feature_class) self.add_feature(feature_class)
def add_feature(self, feature_class, do_import = True): def add_feature(self, feature_class, do_import = True):
......
...@@ -58,6 +58,21 @@ def compute(*nodes): ...@@ -58,6 +58,21 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way).""" """Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set()) compute_from(nodes, set())
def root_inputs(input):
"""Return the leaves of a search through consecutive view_map()s"""
owner = input.owner
if owner:
view_map = owner.view_map()
if input in view_map:
answer = []
for input2 in view_map[input]:
answer.append(root_inputs(input2))
return answer
else:
return [input]
else:
return [input]
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env): def __init__(self, env):
...@@ -346,7 +361,7 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings) ...@@ -346,7 +361,7 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
class NewPythonOp(Op): class NewPythonOp(Op):
__require__ = DestroyHandler __env_require__ = DestroyHandler
def view_map(self): def view_map(self):
return {} return {}
...@@ -517,7 +532,7 @@ class PythonOp(NewPythonOp): ...@@ -517,7 +532,7 @@ class PythonOp(NewPythonOp):
if output not in except_list: if output not in except_list:
output.alloc() output.alloc()
__require__ = ForbidConstantOverwrite __env_require__ = ForbidConstantOverwrite
def __copy__(self): def __copy__(self):
""" """
......
...@@ -28,7 +28,6 @@ class Op(object): ...@@ -28,7 +28,6 @@ class Op(object):
""" """
__slots__ = ['_inputs', '_outputs'] __slots__ = ['_inputs', '_outputs']
__require__ = []
#create inputs and outputs as read-only attributes #create inputs and outputs as read-only attributes
inputs = property(lambda self: self._inputs, doc = "The list of this Op's input Results.") inputs = property(lambda self: self._inputs, doc = "The list of this Op's input Results.")
...@@ -229,11 +228,10 @@ class Op(object): ...@@ -229,11 +228,10 @@ class Op(object):
""" """
r = set() r = set()
bases = all_bases(cls, lambda cls: hasattr(cls, '__require__')) bases = all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
bases.append(cls)
for base in bases: for base in bases:
req = base.__require__ req = base.__env_require__
if isinstance(req, (list, tuple)): if isinstance(req, (list, tuple)):
r.update(req) r.update(req)
else: else:
......
...@@ -9,8 +9,6 @@ import ext ...@@ -9,8 +9,6 @@ import ext
class Optimizer: class Optimizer:
__require__ = ()
def apply(self, env): def apply(self, env):
pass pass
...@@ -18,27 +16,6 @@ class Optimizer: ...@@ -18,27 +16,6 @@ class Optimizer:
env.satisfy(self) env.satisfy(self)
self.apply(env) self.apply(env)
@classmethod
def require(cls):
"""
Returns a list of EnvFeature subclasses that must be used by
any Env manipulating this kind of op. For instance, a
Destroyer requires features.DestroyHandler to guarantee that
various destructive operations don't interfere.
"""
r = set()
bases = utils.all_bases(cls, lambda cls: hasattr(cls, '__require__'))
bases.append(cls)
for base in bases:
req = base.__require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
def __call__(self, env): def __call__(self, env):
self.optimize(env) self.optimize(env)
...@@ -88,7 +65,7 @@ class LocalOptimizer(Optimizer): ...@@ -88,7 +65,7 @@ class LocalOptimizer(Optimizer):
class OpSpecificOptimizer(LocalOptimizer): class OpSpecificOptimizer(LocalOptimizer):
__require__ = features.InstanceFinder __env_require__ = features.InstanceFinder
opclass = Op opclass = Op
...@@ -100,7 +77,7 @@ class OpSpecificOptimizer(LocalOptimizer): ...@@ -100,7 +77,7 @@ class OpSpecificOptimizer(LocalOptimizer):
class OpSubOptimizer(Optimizer): class OpSubOptimizer(Optimizer):
__require__ = features.InstanceFinder __env_require__ = features.InstanceFinder
def __init__(self, op1, op2): def __init__(self, op1, op2):
if not op1.has_default_output: if not op1.has_default_output:
...@@ -127,7 +104,7 @@ class OpSubOptimizer(Optimizer): ...@@ -127,7 +104,7 @@ class OpSubOptimizer(Optimizer):
class OpRemover(Optimizer): class OpRemover(Optimizer):
__require__ = features.InstanceFinder __env_require__ = features.InstanceFinder
def __init__(self, opclass): def __init__(self, opclass):
self.opclass = opclass self.opclass = opclass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论