提交 cca20eb7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

QoL improvements to InteractiveRewrite widget

上级 3876e73d
......@@ -4,7 +4,7 @@ import traitlets
from IPython.display import display
from pytensor.graph import FunctionGraph, Variable, rewrite_graph
from pytensor.graph.features import FullHistory
from pytensor.graph.features import AlreadyThere, FullHistory
class CodeBlockWidget(anywidget.AnyWidget):
......@@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget):
class InteractiveRewrite:
"""
A class that wraps a graph history object with interactive widgets
to navigate through history and display the graph at each step.
Includes an option to display the reason for the last change.
Visualize a graph history through a series of rewrites.
"""
def __init__(self, fg, display_reason=True):
def __init__(
self,
fg,
display_reason=True,
rewrite_options: dict | None = None,
dprint_options: dict | None = None,
):
"""
Initialize with a history object that has a goto method
and tracks a FunctionGraph.
Parameters:
-----------
fg : FunctionGraph (or Variables)
The function graph to track
display_reason : bool, optional
Whether to display the reason for each rewrite
rewrite_options : dict, optional
Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)}
print_options : dict, optional
Print options passed to `debugprint` used to generate the text representation of the graph.
Useful options are {'print_shape': True, 'print_op_info': True}
"""
self.dprint_options = dprint_options or {}
self.rewrite_options = rewrite_options or dict(
include=("fast_run",), exclude=("inplace",)
)
self.history = FullHistory(callback=self._history_callback)
if not isinstance(fg, FunctionGraph):
outs = [fg] if isinstance(fg, Variable) else fg
fg = FunctionGraph(outputs=outs)
fg.attach_feature(self.history)
try:
fg.attach_feature(self.history)
except AlreadyThere:
self.history.end()
self.updating_from_callback = False # Flag to prevent recursion
self.code_widget = CodeBlockWidget(content="")
......@@ -163,7 +175,7 @@ class InteractiveRewrite:
reason = ""
else:
reason = self.history.fw[self.history.pointer].reason
reason = getattr(reason, "name", str(reason))
reason = getattr(reason, "name", None) or str(reason)
self.reason_label.value = f"""
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
......@@ -172,7 +184,9 @@ class InteractiveRewrite:
"""
# Update the graph display
self.code_widget.content = self.history.fg.dprint(file="str")
self.code_widget.content = self.history.fg.dprint(
file="str", **self.dprint_options
)
# Update slider range if history length has changed
history_len = len(self.history.fw) + 1
......@@ -189,14 +203,13 @@ class InteractiveRewrite:
f"History: {self.history.pointer + 1}/{history_len - 1}"
)
def rewrite(self, *args, include=("fast_run",), exclude=("inplace",), **kwargs):
def rewrite(self, *args, **kwargs):
"""Apply rewrites to the current graph"""
rewrite_graph(
self.history.fg,
*args,
include=include,
exclude=exclude,
**kwargs,
**self.rewrite_options,
clone=False,
)
self._update_display()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论