提交 e865889c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make Supervisor an actual Feature

上级 b7589469
...@@ -26,7 +26,7 @@ from aesara.graph.basic import ( ...@@ -26,7 +26,7 @@ from aesara.graph.basic import (
graph_inputs, graph_inputs,
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import Feature, PreserveVariableAttributes from aesara.graph.features import AlreadyThere, Feature, PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph from aesara.graph.op import HasInnerGraph
from aesara.graph.utils import InconsistencyError, get_variable_trace_string from aesara.graph.utils import InconsistencyError, get_variable_trace_string
...@@ -130,7 +130,7 @@ def fgraph_updated_vars(fgraph, expanded_inputs): ...@@ -130,7 +130,7 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
return updated_vars return updated_vars
class Supervisor: class Supervisor(Feature):
""" """
Listener for FunctionGraph events which makes sure that no Listener for FunctionGraph events which makes sure that no
operation overwrites the contents of protected Variables. The operation overwrites the contents of protected Variables. The
...@@ -139,8 +139,19 @@ class Supervisor: ...@@ -139,8 +139,19 @@ class Supervisor:
""" """
def __init__(self, protected): def __init__(self, protected):
self.fgraph = None
self.protected = list(protected) self.protected = list(protected)
def on_attach(self, fgraph):
if hasattr(fgraph, "_supervisor"):
raise AlreadyThere(f"A Supervisor is already attached to {fgraph}.")
if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("This Feature is already associated with a FunctionGraph")
fgraph._supervisor = self
self.fgraph = fgraph
def validate(self, fgraph): def validate(self, fgraph):
if config.cycle_detection == "fast" and hasattr(fgraph, "has_destroyers"): if config.cycle_detection == "fast" and hasattr(fgraph, "has_destroyers"):
if fgraph.has_destroyers(self.protected): if fgraph.has_destroyers(self.protected):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论