提交 8a040b98 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add feature that keeps track of full rewrite history

上级 646a734d
...@@ -438,6 +438,169 @@ class History(Feature): ...@@ -438,6 +438,169 @@ class History(Feature):
self.history[fgraph] = h self.history[fgraph] = h
class FullHistory(Feature):
    """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states

    .. testcode::

        import pytensor
        import pytensor.tensor as pt
        from pytensor.graph.fg import FunctionGraph
        from pytensor.graph.features import FullHistory
        from pytensor.graph.rewriting.utils import rewrite_graph

        x = pt.scalar("x")
        out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
        fg = FunctionGraph(outputs=[out])
        history = FullHistory()
        fg.attach_feature(history)
        rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))

        # Replay rewrites
        history.start()
        pytensor.dprint(fg)
        with pytensor.config.change_flags(optimizer_verbose = True):
            for i in range(3):
                print(">> ", end="")
                pytensor.dprint(history.next())

    .. testoutput::

        Log [id A] 4
         └─ True_div [id B] 3
            ├─ Exp [id C] 2
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id F] 0
                  └─ x [id D]
        >> MergeOptimizer
        Log [id A] 3
         └─ True_div [id B] 2
            ├─ Exp [id C] 0
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id C] 0
                  └─ ···
        >> local_mul_canonizer
        Log [id A] 1
         └─ Softmax{axis=None} [id B] 0
            └─ x [id C]
        >> local_logsoftmax
        LogSoftmax{axis=None} [id A] 0
         └─ x [id B]

    .. testcode::

        # Or in reverse
        with pytensor.config.change_flags(optimizer_verbose=True):
            for i in range(3):
                print(">> ", end="")
                pytensor.dprint(history.prev())

    .. testoutput::

        >> local_logsoftmax
        Log [id A] 1
         └─ Softmax{axis=None} [id B] 0
            └─ x [id C]
        >> local_mul_canonizer
        Log [id A] 3
         └─ True_div [id B] 2
            ├─ Exp [id C] 0
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id C] 0
                  └─ ···
        >> MergeOptimizer
        Log [id A] 4
         └─ True_div [id B] 3
            ├─ Exp [id C] 2
            │  └─ x [id D]
            └─ Sum{axes=None} [id E] 1
               └─ Exp [id F] 0
                  └─ x [id D]

    .. testcode::

        # Or go to any step
        pytensor.dprint(history.goto(2))

    .. testoutput::

        Log [id A] 1
         └─ Softmax{axis=None} [id B] 0
            └─ x [id C]

    """

    def __init__(self) -> None:
        # fw[i] re-applies the i-th recorded replacement (redo patch).
        self.fw = []
        # bw[i] undoes the i-th recorded replacement (undo patch).
        self.bw = []
        # Index of the last applied forward patch; -1 means the initial state
        # (no replacements applied).
        self.pointer = -1
        # The single FunctionGraph this feature is attached to (set in on_attach).
        self.fg = None

    def on_attach(self, fgraph) -> None:
        # Each FullHistory instance may track exactly one graph; its patch
        # lists would be meaningless for any other fgraph.
        if self.fg is not None:
            raise ValueError("Full History already attached to another fgraph")
        self.fg = fgraph

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None) -> None:
        # Record both directions of the replacement: bw restores the old
        # variable `r`, fw re-installs the new variable `new_r`.
        self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
        self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
        self.pointer += 1

    def goto(self, checkpoint):
        """
        Reverts the graph to whatever it was at the provided
        checkpoint (undoes all replacements). A checkpoint at any
        given time can be obtained using self.checkpoint().
        """
        # Snapshot the history length and pointer up front: replaying patches
        # below re-triggers on_change_input, which appends to bw/fw and bumps
        # self.pointer. We work on a local `pointer` and trim the lists after.
        history_len = len(self.bw)
        pointer = self.pointer
        assert 0 <= checkpoint <= history_len
        verbose = config.optimizer_verbose
        # Go backwards
        while pointer > checkpoint - 1:
            reverse_fn = self.bw[pointer]
            if verbose:
                print(reverse_fn.reason)  # noqa: T201
            reverse_fn()
            pointer -= 1
        # Go forward
        while pointer < checkpoint - 1:
            pointer += 1
            forward_fn = self.fw[pointer]
            if verbose:
                print(forward_fn.reason)  # noqa: T201
            forward_fn()
        # Remove history changes caused by the forward/backward!
        self.bw = self.bw[:history_len]
        self.fw = self.fw[:history_len]
        self.pointer = pointer
        return self.fg

    def start(self):
        # Checkpoint 0 is the graph before any recorded rewrite.
        return self.goto(0)

    def end(self):
        # Checkpoint len(bw) is the graph with every recorded rewrite applied.
        return self.goto(len(self.bw))

    def prev(self):
        # Step one rewrite back; no-op at the initial state (pointer == -1).
        if self.pointer < 0:
            return self.fg
        else:
            return self.goto(self.pointer)

    def next(self):
        # Step one rewrite forward; no-op at the final state.
        if self.pointer >= len(self.bw) - 1:
            return self.fg
        else:
            return self.goto(self.pointer + 2)
class Validator(Feature): class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"] pickle_rm_attr = ["validate", "consistent"]
......
import pytest import pytest
from pytensor.graph.basic import Apply, Variable import pytensor.tensor as pt
from pytensor.graph.features import Feature, NodeFinder, ReplaceValidate from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
...@@ -119,3 +121,33 @@ class TestReplaceValidate: ...@@ -119,3 +121,33 @@ class TestReplaceValidate:
capres = capsys.readouterr() capres = capsys.readouterr()
assert "rewriting: validate failed on node Op1.0" in capres.out assert "rewriting: validate failed on node Op1.0" in capres.out
def test_full_history():
    """Exercise FullHistory navigation: start/end, prev/next, and goto."""
    x = pt.scalar("x")
    logsumexp_expr = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
    fgraph = FunctionGraph(outputs=[logsumexp_expr], clone=True, copy_inputs=False)

    tracker = FullHistory()
    fgraph.attach_feature(tracker)
    rewrite_graph(fgraph, clone=False, include=("canonicalize", "stabilize"))

    # start() rewinds to the graph as it was before any rewrite.
    tracker.start()
    assert equal_computations(fgraph.outputs, [logsumexp_expr])

    # end() fast-forwards through every recorded rewrite.
    tracker.end()
    assert equal_computations(fgraph.outputs, [pt.special.log_softmax(x)])

    # One step back undoes only the last rewrite.
    tracker.prev()
    assert equal_computations(fgraph.outputs, [pt.log(pt.special.softmax(x))])

    # prev() saturates at the initial state instead of underflowing.
    for _ in range(10):
        tracker.prev()
    assert equal_computations(fgraph.outputs, [logsumexp_expr])

    # goto() jumps directly to an intermediate checkpoint.
    tracker.goto(2)
    assert equal_computations(fgraph.outputs, [pt.log(pt.special.softmax(x))])

    # next() saturates at the final state instead of overflowing.
    for _ in range(10):
        tracker.next()
    assert equal_computations(fgraph.outputs, [pt.special.log_softmax(x)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论