提交 ec12c35c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Feature.clone method

This new method provides an interface for stateful `Feature`s to be easily cloned and attached to other `FunctionGraph`s.
上级 e865889c
...@@ -142,6 +142,9 @@ class Supervisor(Feature): ...@@ -142,6 +142,9 @@ class Supervisor(Feature):
self.fgraph = None self.fgraph = None
self.protected = list(protected) self.protected = list(protected)
def clone(self):
return type(self)(self.protected)
def on_attach(self, fgraph): def on_attach(self, fgraph):
if hasattr(fgraph, "_supervisor"): if hasattr(fgraph, "_supervisor"):
raise AlreadyThere(f"A Supervisor is already attached to {fgraph}.") raise AlreadyThere(f"A Supervisor is already attached to {fgraph}.")
......
...@@ -330,6 +330,9 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -330,6 +330,9 @@ class DestroyHandler(Bookkeeper): # noqa
self.algo = algo self.algo = algo
self.fail_validate = OrderedDict() self.fail_validate = OrderedDict()
def clone(self):
return type(self)(self.do_imports_on_attach, self.algo)
def on_attach(self, fgraph): def on_attach(self, fgraph):
""" """
When attaching to a new fgraph, check that When attaching to a new fgraph, check that
......
...@@ -325,6 +325,17 @@ class Feature: ...@@ -325,6 +325,17 @@ class Feature:
""" """
return OrderedDict() return OrderedDict()
def clone(self):
"""Create a clone that can be attached to a new `FunctionGraph`.
This default implementation returns `self`, which carries the
assumption that the `Feature` is essentially stateless. If a subclass
has state of its own that is in any way relative to a given
`FunctionGraph`, this method should be overridden with an
implementation that actually creates a fresh copy.
"""
return self
class Bookkeeper(Feature): class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
...@@ -389,6 +400,9 @@ class History(Feature): ...@@ -389,6 +400,9 @@ class History(Feature):
fgraph.checkpoint = GetCheckpoint(self, fgraph) fgraph.checkpoint = GetCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph) fgraph.revert = partial(self.revert, fgraph)
def clone(self):
return type(self)()
def unpickle(self, fgraph): def unpickle(self, fgraph):
fgraph.checkpoint = GetCheckpoint(self, fgraph) fgraph.checkpoint = GetCheckpoint(self, fgraph)
fgraph.revert = partial(self.revert, fgraph) fgraph.revert = partial(self.revert, fgraph)
...@@ -516,6 +530,9 @@ class ReplaceValidate(History, Validator): ...@@ -516,6 +530,9 @@ class ReplaceValidate(History, Validator):
Validator.on_attach(self, fgraph) Validator.on_attach(self, fgraph)
self.unpickle(fgraph) self.unpickle(fgraph)
def clone(self):
return type(self)()
def unpickle(self, fgraph): def unpickle(self, fgraph):
History.unpickle(self, fgraph) History.unpickle(self, fgraph)
Validator.unpickle(self, fgraph) Validator.unpickle(self, fgraph)
...@@ -643,12 +660,17 @@ class NodeFinder(Bookkeeper): ...@@ -643,12 +660,17 @@ class NodeFinder(Bookkeeper):
def on_attach(self, fgraph): def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"): if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present") raise AlreadyThere("NodeFinder is already present")
if self.fgraph is not None and self.fgraph != fgraph: if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.") raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
self.fgraph = fgraph self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph) fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph) Bookkeeper.on_attach(self, fgraph)
def clone(self):
return type(self)()
def on_detach(self, fgraph): def on_detach(self, fgraph):
""" """
Should remove any dynamically added functionality Should remove any dynamically added functionality
...@@ -751,6 +773,9 @@ class NoOutputFromInplace(Feature): ...@@ -751,6 +773,9 @@ class NoOutputFromInplace(Feature):
fgraph._no_output_from_inplace = self fgraph._no_output_from_inplace = self
def clone(self):
return type(self)(self.protected_out_ids)
def validate(self, fgraph): def validate(self, fgraph):
if not hasattr(fgraph, "destroyers"): if not hasattr(fgraph, "destroyers"):
return True return True
......
...@@ -900,7 +900,7 @@ class FunctionGraph(MetaObject): ...@@ -900,7 +900,7 @@ class FunctionGraph(MetaObject):
if attach_feature: if attach_feature:
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature.clone())
return e, equiv return e, equiv
def __getstate__(self): def __getstate__(self):
......
...@@ -534,6 +534,9 @@ class MergeFeature(Feature): ...@@ -534,6 +534,9 @@ class MergeFeature(Feature):
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach") self.on_import(fgraph, node, "on_attach")
def clone(self):
return type(self)()
def on_change_input(self, fgraph, node, i, r, new_r, reason): def on_change_input(self, fgraph, node, i, r, new_r, reason):
if node in self.nodes_seen: if node in self.nodes_seen:
# If inputs to a node change, it's not guaranteed that the node is # If inputs to a node change, it's not guaranteed that the node is
...@@ -2106,6 +2109,9 @@ class ChangeTracker(Feature): ...@@ -2106,6 +2109,9 @@ class ChangeTracker(Feature):
self.changed = False self.changed = False
self.nb_imported = 0 self.nb_imported = 0
def clone(self):
return type(self)()
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
self.nb_imported += 1 self.nb_imported += 1
self.changed = True self.changed = True
......
...@@ -15,7 +15,6 @@ import aesara.scalar.basic as aes ...@@ -15,7 +15,6 @@ import aesara.scalar.basic as aes
from aesara import compile from aesara import compile
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import features
from aesara.graph.basic import ( from aesara.graph.basic import (
Constant, Constant,
Variable, Variable,
...@@ -23,6 +22,7 @@ from aesara.graph.basic import ( ...@@ -23,6 +22,7 @@ from aesara.graph.basic import (
equal_computations, equal_computations,
io_toposort, io_toposort,
) )
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.opt import ( from aesara.graph.opt import (
...@@ -775,7 +775,7 @@ class MakeVectorPrinter(Printer): ...@@ -775,7 +775,7 @@ class MakeVectorPrinter(Printer):
pprint.assign(MakeVector, MakeVectorPrinter()) pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(features.Feature): class ShapeFeature(Feature):
"""Graph optimizer for removing all calls to shape(). """Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with This optimizer replaces all Shapes and Subtensors of Shapes with
...@@ -1230,7 +1230,7 @@ class ShapeFeature(features.Feature): ...@@ -1230,7 +1230,7 @@ class ShapeFeature(features.Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
if hasattr(fgraph, "shape_feature"): if hasattr(fgraph, "shape_feature"):
raise features.AlreadyThere("This FunctionGraph already has a ShapeFeature") raise AlreadyThere("This FunctionGraph already has a ShapeFeature")
if hasattr(self, "fgraph") and self.fgraph != fgraph: if hasattr(self, "fgraph") and self.fgraph != fgraph:
raise Exception("This ShapeFeature is already attached to a graph") raise Exception("This ShapeFeature is already attached to a graph")
...@@ -1453,6 +1453,9 @@ class ShapeFeature(features.Feature): ...@@ -1453,6 +1453,9 @@ class ShapeFeature(features.Feature):
return True return True
def clone(self):
return type(self)()
class ShapeOptimizer(GlobalOptimizer): class ShapeOptimizer(GlobalOptimizer):
"""Optimizer that adds `ShapeFeature` as a feature.""" """Optimizer that adds `ShapeFeature` as a feature."""
...@@ -3275,7 +3278,7 @@ class FusionOptimizer(GlobalOptimizer): ...@@ -3275,7 +3278,7 @@ class FusionOptimizer(GlobalOptimizer):
self.optimizer = local_optimizer self.optimizer = local_optimizer
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(features.ReplaceValidate()) fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
did_something = True did_something = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论