提交 ec12c35c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Feature.clone method

This new method provides an interface for stateful `Feature`s to be easily cloned and attached to other `FunctionGraph`s.
上级 e865889c
......@@ -142,6 +142,9 @@ class Supervisor(Feature):
self.fgraph = None
self.protected = list(protected)
def clone(self):
return type(self)(self.protected)
def on_attach(self, fgraph):
if hasattr(fgraph, "_supervisor"):
raise AlreadyThere(f"A Supervisor is already attached to {fgraph}.")
......
......@@ -330,6 +330,9 @@ class DestroyHandler(Bookkeeper): # noqa
self.algo = algo
self.fail_validate = OrderedDict()
def clone(self):
return type(self)(self.do_imports_on_attach, self.algo)
def on_attach(self, fgraph):
"""
When attaching to a new fgraph, check that
......
......@@ -325,6 +325,17 @@ class Feature:
"""
return OrderedDict()
def clone(self):
"""Create a clone that can be attached to a new `FunctionGraph`.
This default implementation returns `self`, which carries the
assumption that the `Feature` is essentially stateless. If a subclass
has state of its own that is in any way relative to a given
`FunctionGraph`, this method should be overridden with an
implementation that actually creates a fresh copy.
"""
return self
class Bookkeeper(Feature):
def on_attach(self, fgraph):
......@@ -389,6 +400,9 @@ class History(Feature):
fgraph.checkpoint = GetCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph)
def clone(self):
return type(self)()
def unpickle(self, fgraph):
fgraph.checkpoint = GetCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph)
......@@ -516,6 +530,9 @@ class ReplaceValidate(History, Validator):
Validator.on_attach(self, fgraph)
self.unpickle(fgraph)
def clone(self):
return type(self)()
def unpickle(self, fgraph):
History.unpickle(self, fgraph)
Validator.unpickle(self, fgraph)
......@@ -643,12 +660,17 @@ class NodeFinder(Bookkeeper):
def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present")
if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph)
def clone(self):
return type(self)()
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
......@@ -751,6 +773,9 @@ class NoOutputFromInplace(Feature):
fgraph._no_output_from_inplace = self
def clone(self):
return type(self)(self.protected_out_ids)
def validate(self, fgraph):
if not hasattr(fgraph, "destroyers"):
return True
......
......@@ -900,7 +900,7 @@ class FunctionGraph(MetaObject):
if attach_feature:
for feature in self._features:
e.attach_feature(feature)
e.attach_feature(feature.clone())
return e, equiv
def __getstate__(self):
......
......@@ -534,6 +534,9 @@ class MergeFeature(Feature):
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def clone(self):
return type(self)()
def on_change_input(self, fgraph, node, i, r, new_r, reason):
if node in self.nodes_seen:
# If inputs to a node change, it's not guaranteed that the node is
......@@ -2106,6 +2109,9 @@ class ChangeTracker(Feature):
self.changed = False
self.nb_imported = 0
def clone(self):
return type(self)()
def on_import(self, fgraph, node, reason):
self.nb_imported += 1
self.changed = True
......
......@@ -15,7 +15,6 @@ import aesara.scalar.basic as aes
from aesara import compile
from aesara.compile.ops import ViewOp
from aesara.configdefaults import config
from aesara.graph import features
from aesara.graph.basic import (
Constant,
Variable,
......@@ -23,6 +22,7 @@ from aesara.graph.basic import (
equal_computations,
io_toposort,
)
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.opt import (
......@@ -775,7 +775,7 @@ class MakeVectorPrinter(Printer):
pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(features.Feature):
class ShapeFeature(Feature):
"""Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with
......@@ -1230,7 +1230,7 @@ class ShapeFeature(features.Feature):
def on_attach(self, fgraph):
if hasattr(fgraph, "shape_feature"):
raise features.AlreadyThere("This FunctionGraph already has a ShapeFeature")
raise AlreadyThere("This FunctionGraph already has a ShapeFeature")
if hasattr(self, "fgraph") and self.fgraph != fgraph:
raise Exception("This ShapeFeature is already attached to a graph")
......@@ -1453,6 +1453,9 @@ class ShapeFeature(features.Feature):
return True
def clone(self):
return type(self)()
class ShapeOptimizer(GlobalOptimizer):
"""Optimizer that adds `ShapeFeature` as a feature."""
......@@ -3275,7 +3278,7 @@ class FusionOptimizer(GlobalOptimizer):
self.optimizer = local_optimizer
def add_requirements(self, fgraph):
fgraph.attach_feature(features.ReplaceValidate())
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
did_something = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论