提交 8db93fe1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename CheckStackTraceOptimization to CheckStackTraceRewriter

上级 92a3b2a6
...@@ -11,7 +11,7 @@ from aesara.compile.function.types import Supervisor ...@@ -11,7 +11,7 @@ from aesara.compile.function.types import Supervisor
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.opt import ( from aesara.graph.opt import (
CheckStackTraceOptimization, CheckStackTraceRewriter,
GraphRewriter, GraphRewriter,
MergeOptimizer, MergeOptimizer,
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
...@@ -271,7 +271,7 @@ if config.check_stack_trace in ("raise", "warn", "log"): ...@@ -271,7 +271,7 @@ if config.check_stack_trace in ("raise", "warn", "log"):
if config.check_stack_trace == "off": if config.check_stack_trace == "off":
_tags = () _tags = ()
optdb.register("CheckStackTrace", CheckStackTraceOptimization(), *_tags, position=-1) optdb.register("CheckStackTrace", CheckStackTraceRewriter(), *_tags, position=-1)
del _tags del _tags
......
...@@ -3091,13 +3091,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): ...@@ -3091,13 +3091,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
class CheckStackTraceFeature(Feature): class CheckStackTraceFeature(Feature):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
# In optdb we only register the CheckStackTraceOptimization when # In `optdb` we only register the `CheckStackTraceRewriter` when
# config.check_stack_trace is not off but we also double check here. # `config.check_stack_trace` is not off, but we also double check here.
if config.check_stack_trace != "off" and not check_stack_trace(fgraph, "all"): if config.check_stack_trace != "off" and not check_stack_trace(fgraph, "all"):
if config.check_stack_trace == "raise": if config.check_stack_trace == "raise":
raise AssertionError( raise AssertionError(
"Empty stack trace! The optimization that inserted this variable is " f"Empty stack trace. The rewrite that inserted this variable is {reason}."
+ str(reason)
) )
elif config.check_stack_trace in ("log", "warn"): elif config.check_stack_trace in ("log", "warn"):
apply_nodes_to_check = fgraph.apply_nodes apply_nodes_to_check = fgraph.apply_nodes
...@@ -3109,22 +3108,19 @@ class CheckStackTraceFeature(Feature): ...@@ -3109,22 +3108,19 @@ class CheckStackTraceFeature(Feature):
( (
"", "",
0, 0,
"Empty stack trace! The optimization that" f"Empty stack trace. The rewrite that inserted this variable is {reason}.",
+ "inserted this variable is "
+ str(reason),
"", "",
) )
] ]
] ]
if config.check_stack_trace == "warn": if config.check_stack_trace == "warn":
warnings.warn( warnings.warn(
"Empty stack trace! The optimization that inserted this variable is" f"Empty stack trace. The rewrite that inserted this variable is {reason}."
+ str(reason)
) )
class CheckStackTraceOptimization(GraphRewriter): class CheckStackTraceRewriter(GraphRewriter):
"""Optimizer that serves to add `CheckStackTraceOptimization` as a feature.""" """Rewriter that serves to add `CheckStackTraceRewriter` as a feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
if not hasattr(fgraph, "CheckStackTraceFeature"): if not hasattr(fgraph, "CheckStackTraceFeature"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论