提交 4be377bb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3586 from nouiz/fg_history

[ENH] Keep an history of only 1 checkpoint of FunctionGraph
...@@ -147,7 +147,6 @@ def test_misc(): ...@@ -147,7 +147,6 @@ def test_misc():
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
consistent(g) consistent(g)
chk = g.checkpoint()
PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g) PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
new_e = add(x, y) new_e = add(x, y)
...@@ -156,9 +155,6 @@ def test_misc(): ...@@ -156,9 +155,6 @@ def test_misc():
g.replace(new_e, dot(add_in_place(x, y), transpose_view(x))) g.replace(new_e, dot(add_in_place(x, y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
inconsistent(g) inconsistent(g)
g.revert(chk)
consistent(g)
assert str(g) == "[TransposeView(TransposeView(TransposeView(TransposeView(x))))]"
###################### ######################
...@@ -213,7 +209,6 @@ def test_destroyers_loop(): ...@@ -213,7 +209,6 @@ def test_destroyers_loop():
e1 = add(x, y) e1 = add(x, y)
e2 = add(y, x) e2 = add(y, x)
g = Env([x, y, z], [e1, e2]) g = Env([x, y, z], [e1, e2])
chk = g.checkpoint()
consistent(g) consistent(g)
g.replace_validate(e1, add_in_place(x, y)) g.replace_validate(e1, add_in_place(x, y))
consistent(g) consistent(g)
...@@ -223,7 +218,12 @@ def test_destroyers_loop(): ...@@ -223,7 +218,12 @@ def test_destroyers_loop():
except InconsistencyError: except InconsistencyError:
pass pass
consistent(g) consistent(g)
g.revert(chk)
x, y, z = inputs()
e1 = add(x, y)
e2 = add(y, x)
g = Env([x, y, z], [e1, e2])
consistent(g)
g.replace_validate(e2, add_in_place(y, x)) g.replace_validate(e2, add_in_place(y, x))
consistent(g) consistent(g)
try: try:
......
...@@ -127,9 +127,12 @@ class GetCheckpoint: ...@@ -127,9 +127,12 @@ class GetCheckpoint:
def __init__(self, history, fgraph): def __init__(self, history, fgraph):
self.h = history self.h = history
self.fgraph = fgraph self.fgraph = fgraph
self.nb = 0
def __call__(self): def __call__(self):
return len(self.h.history[self.fgraph]) self.h.history[self.fgraph] = []
self.nb += 1
return self.nb
class LambdExtract: class LambdExtract:
...@@ -147,6 +150,13 @@ class LambdExtract: ...@@ -147,6 +150,13 @@ class LambdExtract:
class History(Feature): class History(Feature):
"""Keep an history of changes to an FunctionGraph.
This history can be reverted up to the last checkpoint.. We can
revert to only 1 point in the past. This limit was added to lower
the memory usage.
"""
pickle_rm_attr = ["checkpoint", "revert"] pickle_rm_attr = ["checkpoint", "revert"]
def __init__(self): def __init__(self):
...@@ -187,7 +197,8 @@ class History(Feature): ...@@ -187,7 +197,8 @@ class History(Feature):
""" """
h = self.history[fgraph] h = self.history[fgraph]
self.history[fgraph] = None self.history[fgraph] = None
while len(h) > checkpoint: assert fgraph.checkpoint.nb == checkpoint
while h:
f = h.pop() f = h.pop()
f() f()
self.history[fgraph] = h self.history[fgraph] = h
...@@ -314,6 +325,7 @@ class ReplaceValidate(History, Validator): ...@@ -314,6 +325,7 @@ class ReplaceValidate(History, Validator):
raise raise
if verbose: if verbose:
print(reason, r, new_r) print(reason, r, new_r)
# The return is needed by replace_all_validate_remove
return chk return chk
def replace_all_validate_remove(self, fgraph, replacements, def replace_all_validate_remove(self, fgraph, replacements,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论