提交 8a040b98 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add feature that keeps track of full rewrite history

上级 646a734d
......@@ -438,6 +438,169 @@ class History(Feature):
self.history[fgraph] = h
class FullHistory(Feature):
"""Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states
.. testcode::
import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.features import FullHistory
from pytensor.graph.rewriting.utils import rewrite_graph
x = pt.scalar("x")
out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
fg = FunctionGraph(outputs=[out])
history = FullHistory()
fg.attach_feature(history)
rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
# Replay rewrites
history.start()
pytensor.dprint(fg)
with pytensor.config.change_flags(optimizer_verbose = True):
for i in range(3):
print(">> ", end="")
pytensor.dprint(history.next())
.. testoutput::
Log [id A] 4
└─ True_div [id B] 3
├─ Exp [id C] 2
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id F] 0
└─ x [id D]
>> MergeOptimizer
Log [id A] 3
└─ True_div [id B] 2
├─ Exp [id C] 0
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id C] 0
└─ ···
>> local_mul_canonizer
Log [id A] 1
└─ Softmax{axis=None} [id B] 0
└─ x [id C]
>> local_logsoftmax
LogSoftmax{axis=None} [id A] 0
└─ x [id B]
.. testcode::
# Or in reverse
with pytensor.config.change_flags(optimizer_verbose=True):
for i in range(3):
print(">> ", end="")
pytensor.dprint(history.prev())
.. testoutput::
>> local_logsoftmax
Log [id A] 1
└─ Softmax{axis=None} [id B] 0
└─ x [id C]
>> local_mul_canonizer
Log [id A] 3
└─ True_div [id B] 2
├─ Exp [id C] 0
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id C] 0
└─ ···
>> MergeOptimizer
Log [id A] 4
└─ True_div [id B] 3
├─ Exp [id C] 2
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id F] 0
└─ x [id D]
.. testcode::
# Or go to any step
pytensor.dprint(history.goto(2))
.. testoutput::
Log [id A] 1
└─ Softmax{axis=None} [id B] 0
└─ x [id C]
"""
def __init__(self):
self.fw = []
self.bw = []
self.pointer = -1
self.fg = None
def on_attach(self, fgraph):
if self.fg is not None:
raise ValueError("Full History already attached to another fgraph")
self.fg = fgraph
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
self.pointer += 1
def goto(self, checkpoint):
"""
Reverts the graph to whatever it was at the provided
checkpoint (undoes all replacements). A checkpoint at any
given time can be obtained using self.checkpoint().
"""
history_len = len(self.bw)
pointer = self.pointer
assert 0 <= checkpoint <= history_len
verbose = config.optimizer_verbose
# Go backwards
while pointer > checkpoint - 1:
reverse_fn = self.bw[pointer]
if verbose:
print(reverse_fn.reason) # noqa: T201
reverse_fn()
pointer -= 1
# Go forward
while pointer < checkpoint - 1:
pointer += 1
forward_fn = self.fw[pointer]
if verbose:
print(forward_fn.reason) # noqa: T201
forward_fn()
# Remove history changes caused by the foward/backward!
self.bw = self.bw[:history_len]
self.fw = self.fw[:history_len]
self.pointer = pointer
return self.fg
def start(self):
return self.goto(0)
def end(self):
return self.goto(len(self.bw))
def prev(self):
if self.pointer < 0:
return self.fg
else:
return self.goto(self.pointer)
def next(self):
if self.pointer >= len(self.bw) - 1:
return self.fg
else:
return self.goto(self.pointer + 2)
class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"]
......
import pytest
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.features import Feature, NodeFinder, ReplaceValidate
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
......@@ -119,3 +121,33 @@ class TestReplaceValidate:
capres = capsys.readouterr()
assert "rewriting: validate failed on node Op1.0" in capres.out
def test_full_history():
x = pt.scalar("x")
out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
fg = FunctionGraph(outputs=[out], clone=True, copy_inputs=False)
history = FullHistory()
fg.attach_feature(history)
rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
history.start()
assert equal_computations(fg.outputs, [out])
history.end()
assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])
history.prev()
assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])
for i in range(10):
history.prev()
assert equal_computations(fg.outputs, [out])
history.goto(2)
assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])
for i in range(10):
history.next()
assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论