提交 cca20eb7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

QoL improvements to InteractiveRewrite widget

上级 3876e73d
...@@ -4,7 +4,7 @@ import traitlets ...@@ -4,7 +4,7 @@ import traitlets
from IPython.display import display from IPython.display import display
from pytensor.graph import FunctionGraph, Variable, rewrite_graph from pytensor.graph import FunctionGraph, Variable, rewrite_graph
from pytensor.graph.features import FullHistory from pytensor.graph.features import AlreadyThere, FullHistory
class CodeBlockWidget(anywidget.AnyWidget): class CodeBlockWidget(anywidget.AnyWidget):
...@@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget): ...@@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget):
class InteractiveRewrite: class InteractiveRewrite:
""" """
A class that wraps a graph history object with interactive widgets Visualize a graph history through a series of rewrites.
to navigate through history and display the graph at each step.
Includes an option to display the reason for the last change.
""" """
def __init__(self, fg, display_reason=True): def __init__(
self,
fg,
display_reason=True,
rewrite_options: dict | None = None,
dprint_options: dict | None = None,
):
""" """
Initialize with a history object that has a goto method
and tracks a FunctionGraph.
Parameters: Parameters:
----------- -----------
fg : FunctionGraph (or Variables) fg : FunctionGraph (or Variables)
The function graph to track The function graph to track
display_reason : bool, optional display_reason : bool, optional
Whether to display the reason for each rewrite Whether to display the reason for each rewrite
rewrite_options : dict, optional
Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)}
print_options : dict, optional
Print options passed to `debugprint` used to generate the text representation of the graph.
Useful options are {'print_shape': True, 'print_op_info': True}
""" """
self.dprint_options = dprint_options or {}
self.rewrite_options = rewrite_options or dict(
include=("fast_run",), exclude=("inplace",)
)
self.history = FullHistory(callback=self._history_callback) self.history = FullHistory(callback=self._history_callback)
if not isinstance(fg, FunctionGraph): if not isinstance(fg, FunctionGraph):
outs = [fg] if isinstance(fg, Variable) else fg outs = [fg] if isinstance(fg, Variable) else fg
fg = FunctionGraph(outputs=outs) fg = FunctionGraph(outputs=outs)
try:
fg.attach_feature(self.history) fg.attach_feature(self.history)
except AlreadyThere:
self.history.end()
self.updating_from_callback = False # Flag to prevent recursion self.updating_from_callback = False # Flag to prevent recursion
self.code_widget = CodeBlockWidget(content="") self.code_widget = CodeBlockWidget(content="")
...@@ -163,7 +175,7 @@ class InteractiveRewrite: ...@@ -163,7 +175,7 @@ class InteractiveRewrite:
reason = "" reason = ""
else: else:
reason = self.history.fw[self.history.pointer].reason reason = self.history.fw[self.history.pointer].reason
reason = getattr(reason, "name", str(reason)) reason = getattr(reason, "name", None) or str(reason)
self.reason_label.value = f""" self.reason_label.value = f"""
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'> <div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
...@@ -172,7 +184,9 @@ class InteractiveRewrite: ...@@ -172,7 +184,9 @@ class InteractiveRewrite:
""" """
# Update the graph display # Update the graph display
self.code_widget.content = self.history.fg.dprint(file="str") self.code_widget.content = self.history.fg.dprint(
file="str", **self.dprint_options
)
# Update slider range if history length has changed # Update slider range if history length has changed
history_len = len(self.history.fw) + 1 history_len = len(self.history.fw) + 1
...@@ -189,14 +203,13 @@ class InteractiveRewrite: ...@@ -189,14 +203,13 @@ class InteractiveRewrite:
f"History: {self.history.pointer + 1}/{history_len - 1}" f"History: {self.history.pointer + 1}/{history_len - 1}"
) )
def rewrite(self, *args, include=("fast_run",), exclude=("inplace",), **kwargs): def rewrite(self, *args, **kwargs):
"""Apply rewrites to the current graph""" """Apply rewrites to the current graph"""
rewrite_graph( rewrite_graph(
self.history.fg, self.history.fg,
*args, *args,
include=include,
exclude=exclude,
**kwargs, **kwargs,
**self.rewrite_options,
clone=False, clone=False,
) )
self._update_display() self._update_display()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论