提交 4be377bb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3586 from nouiz/fg_history

[ENH] Keep an history of only 1 checkpoint of FunctionGraph
......@@ -147,7 +147,6 @@ def test_misc():
e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = Env([x, y, z], [e])
consistent(g)
chk = g.checkpoint()
PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]"
new_e = add(x, y)
......@@ -156,9 +155,6 @@ def test_misc():
g.replace(new_e, dot(add_in_place(x, y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
inconsistent(g)
g.revert(chk)
consistent(g)
assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
######################
......@@ -213,7 +209,6 @@ def test_destroyers_loop():
e1 = add(x, y)
e2 = add(y, x)
g = Env([x, y, z], [e1, e2])
chk = g.checkpoint()
consistent(g)
g.replace_validate(e1, add_in_place(x, y))
consistent(g)
......@@ -223,7 +218,12 @@ def test_destroyers_loop():
except InconsistencyError:
pass
consistent(g)
g.revert(chk)
x, y, z = inputs()
e1 = add(x, y)
e2 = add(y, x)
g = Env([x, y, z], [e1, e2])
consistent(g)
g.replace_validate(e2, add_in_place(y, x))
consistent(g)
try:
......
......@@ -127,9 +127,12 @@ class GetCheckpoint:
def __init__(self, history, fgraph):
self.h = history
self.fgraph = fgraph
self.nb = 0
def __call__(self):
return len(self.h.history[self.fgraph])
self.h.history[self.fgraph] = []
self.nb += 1
return self.nb
class LambdExtract:
......@@ -147,6 +150,13 @@ class LambdExtract:
class History(Feature):
"""Keep an history of changes to an FunctionGraph.
This history can be reverted up to the last checkpoint.. We can
revert to only 1 point in the past. This limit was added to lower
the memory usage.
"""
pickle_rm_attr = ["checkpoint", "revert"]
def __init__(self):
......@@ -187,7 +197,8 @@ class History(Feature):
"""
h = self.history[fgraph]
self.history[fgraph] = None
while len(h) > checkpoint:
assert fgraph.checkpoint.nb == checkpoint
while h:
f = h.pop()
f()
self.history[fgraph] = h
......@@ -314,6 +325,7 @@ class ReplaceValidate(History, Validator):
raise
if verbose:
print(reason, r, new_r)
# The return is needed by replace_all_validate_remove
return chk
def replace_all_validate_remove(self, fgraph, replacements,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论