提交 6847e573 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix inconsistent Feature.on_attach checks

A few core `Feature` implementations were using `assert`s and generic `Exception` instead of `AlreadyThere`, which unnecessarily produces errors when already-attached `Feature` requirements are added by `Rewriter`s.
上级 e51be01c
...@@ -31,7 +31,7 @@ from aesara.compile.ops import OutputGuard, _output_guard ...@@ -31,7 +31,7 @@ from aesara.compile.ops import OutputGuard, _output_guard
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort from aesara.graph.basic import Variable, io_toposort
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization from aesara.graph.features import AlreadyThere, BadOptimization
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import InconsistencyError, MethodNotDefined from aesara.graph.utils import InconsistencyError, MethodNotDefined
from aesara.link.basic import Container, LocalLinker from aesara.link.basic import Container, LocalLinker
...@@ -1206,7 +1206,9 @@ class _VariableEquivalenceTracker: ...@@ -1206,7 +1206,9 @@ class _VariableEquivalenceTracker:
self.fgraph = None self.fgraph = None
def on_attach(self, fgraph): def on_attach(self, fgraph):
assert self.fgraph is None if self.fgraph is not None:
raise AlreadyThere()
self.equiv = {} self.equiv = {}
self.active_nodes = set() self.active_nodes = set()
self.inactive_nodes = set() self.inactive_nodes = set()
......
...@@ -347,24 +347,12 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -347,24 +347,12 @@ class DestroyHandler(Bookkeeper): # noqa
""" """
# Do the checking # if any(hasattr(fgraph, attr) for attr in ("destroyers", "destroy_handler")):
already_there = False raise AlreadyThere("DestroyHandler feature is already present")
if self.fgraph is fgraph:
already_there = True if self.fgraph is not None and self.fgraph != fgraph:
if self.fgraph is not None:
raise Exception( raise Exception(
"A DestroyHandler instance can only serve one" "A DestroyHandler instance can only serve one FunctionGraph"
" FunctionGraph. (Matthew 6:24)"
)
for attr in ("destroyers", "destroy_handler"):
if hasattr(fgraph, attr):
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise AlreadyThere(
"DestroyHandler feature is already present"
" or in conflict with another plugin."
) )
# Annotate the FunctionGraph # # Annotate the FunctionGraph #
......
...@@ -641,14 +641,10 @@ class NodeFinder(Bookkeeper): ...@@ -641,14 +641,10 @@ class NodeFinder(Bookkeeper):
self.d = {} self.d = {}
def on_attach(self, fgraph): def on_attach(self, fgraph):
if self.fgraph is not None:
raise Exception(
"A NodeFinder instance can only serve one " "FunctionGraph."
)
if hasattr(fgraph, "get_nodes"): if hasattr(fgraph, "get_nodes"):
raise AlreadyThere( raise AlreadyThere("NodeFinder is already present")
"NodeFinder is already present or in conflict" " with another plugin." if self.fgraph is not None and self.fgraph != fgraph:
) raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
self.fgraph = fgraph self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph) fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph) Bookkeeper.on_attach(self, fgraph)
......
...@@ -31,7 +31,7 @@ from aesara.graph.basic import ( ...@@ -31,7 +31,7 @@ from aesara.graph.basic import (
io_toposort, io_toposort,
vars_between, vars_between,
) )
from aesara.graph.features import Feature, NodeFinder from aesara.graph.features import AlreadyThere, Feature, NodeFinder
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.utils import AssocList, InconsistencyError from aesara.graph.utils import AssocList, InconsistencyError
...@@ -496,7 +496,9 @@ class MergeFeature(Feature): ...@@ -496,7 +496,9 @@ class MergeFeature(Feature):
""" """
def on_attach(self, fgraph): def on_attach(self, fgraph):
assert not hasattr(fgraph, "merge_feature") if hasattr(fgraph, "merge_feature"):
raise AlreadyThere()
fgraph.merge_feature = self fgraph.merge_feature = self
self.seen_atomics = set() self.seen_atomics = set()
...@@ -2111,6 +2113,8 @@ class ChangeTracker(Feature): ...@@ -2111,6 +2113,8 @@ class ChangeTracker(Feature):
self.changed = False self.changed = False
def on_attach(self, fgraph): def on_attach(self, fgraph):
if hasattr(fgraph, "change_tracker"):
raise AlreadyThere()
fgraph.change_tracker = self fgraph.change_tracker = self
def on_detach(self, fgraph): def on_detach(self, fgraph):
......
...@@ -1229,13 +1229,13 @@ class ShapeFeature(features.Feature): ...@@ -1229,13 +1229,13 @@ class ShapeFeature(features.Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
if getattr(self, "fgraph", None): if hasattr(fgraph, "shape_feature"):
raise ValueError("This ShapeFeature is already attached to a graph") raise features.AlreadyThere("This FunctionGraph already has a ShapeFeature")
self.fgraph = fgraph if hasattr(self, "fgraph") and self.fgraph != fgraph:
raise Exception("This ShapeFeature is already attached to a graph")
if hasattr(fgraph, "shape_feature"): self.fgraph = fgraph
raise ValueError("This FunctionGraph already has a ShapeFeature")
fgraph.shape_feature = self fgraph.shape_feature = self
# Must be local to the object as otherwise we reuse the same # Must be local to the object as otherwise we reuse the same
......
...@@ -119,6 +119,10 @@ check_untyped_defs = False ...@@ -119,6 +119,10 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.compile.function.pfunc]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.function.types] [mypy-aesara.compile.function.types]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论