提交 8db93fe1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename CheckStackTraceOptimization to CheckStackTraceRewriter

上级 92a3b2a6
......@@ -11,7 +11,7 @@ from aesara.compile.function.types import Supervisor
from aesara.configdefaults import config
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.opt import (
CheckStackTraceOptimization,
CheckStackTraceRewriter,
GraphRewriter,
MergeOptimizer,
NodeProcessingGraphRewriter,
......@@ -271,7 +271,7 @@ if config.check_stack_trace in ("raise", "warn", "log"):
if config.check_stack_trace == "off":
_tags = ()
optdb.register("CheckStackTrace", CheckStackTraceOptimization(), *_tags, position=-1)
optdb.register("CheckStackTrace", CheckStackTraceRewriter(), *_tags, position=-1)
del _tags
......
......@@ -3091,13 +3091,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
class CheckStackTraceFeature(Feature):
def on_import(self, fgraph, node, reason):
# In optdb we only register the CheckStackTraceOptimization when
# config.check_stack_trace is not off but we also double check here.
# In `optdb` we only register the `CheckStackTraceRewriter` when
# `config.check_stack_trace` is not off, but we also double check here.
if config.check_stack_trace != "off" and not check_stack_trace(fgraph, "all"):
if config.check_stack_trace == "raise":
raise AssertionError(
"Empty stack trace! The optimization that inserted this variable is "
+ str(reason)
f"Empty stack trace. The rewrite that inserted this variable is {reason}."
)
elif config.check_stack_trace in ("log", "warn"):
apply_nodes_to_check = fgraph.apply_nodes
......@@ -3109,22 +3108,19 @@ class CheckStackTraceFeature(Feature):
(
"",
0,
"Empty stack trace! The optimization that"
+ "inserted this variable is "
+ str(reason),
f"Empty stack trace. The rewrite that inserted this variable is {reason}.",
"",
)
]
]
if config.check_stack_trace == "warn":
warnings.warn(
"Empty stack trace! The optimization that inserted this variable is"
+ str(reason)
f"Empty stack trace. The rewrite that inserted this variable is {reason}."
)
class CheckStackTraceOptimization(GraphRewriter):
"""Optimizer that serves to add `CheckStackTraceOptimization` as a feature."""
class CheckStackTraceRewriter(GraphRewriter):
"""Rewriter that serves to add `CheckStackTraceRewriter` as a feature."""
def add_requirements(self, fgraph):
if not hasattr(fgraph, "CheckStackTraceFeature"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论