Unverified 提交 c58f10be authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix rewrite weakref leak (#1660)

Calling `lru_cache` on instance methods causes a leak. Fixed as suggested in https://rednafi.com/python/lru-cache-on-methods/
上级 a99feb55
...@@ -1073,6 +1073,7 @@ class OpToRewriterTracker: ...@@ -1073,6 +1073,7 @@ class OpToRewriterTracker:
defaultdict(lambda: defaultdict(list)) defaultdict(lambda: defaultdict(list))
) )
self.untracked_rewrites: list[NodeRewriter] = [] self.untracked_rewrites: list[NodeRewriter] = []
self.get_trackers = functools.cache(self._get_trackers)
self._cached_composed_mro = None self._cached_composed_mro = None
def add_tracker(self, rw: NodeRewriter): def add_tracker(self, rw: NodeRewriter):
...@@ -1080,6 +1081,7 @@ class OpToRewriterTracker: ...@@ -1080,6 +1081,7 @@ class OpToRewriterTracker:
if self._cached_composed_mro is not None: if self._cached_composed_mro is not None:
# We shouldn't actually add_trackers after the first call to get_trackers # We shouldn't actually add_trackers after the first call to get_trackers
# But just to be safe we kill the cache here # But just to be safe we kill the cache here
self.get_trackers = functools.cache(self._get_trackers)
self._cached_composed_mro = None self._cached_composed_mro = None
tracks = rw.tracks() tracks = rw.tracks()
...@@ -1107,8 +1109,7 @@ class OpToRewriterTracker: ...@@ -1107,8 +1109,7 @@ class OpToRewriterTracker:
else: else:
self.tracked_instances[c].append(rw) self.tracked_instances[c].append(rw)
@functools.cache def _get_trackers(self, op: Op) -> list[NodeRewriter]:
def get_trackers(self, op: Op) -> list[NodeRewriter]:
"""Get all the rewrites applicable to an `Op`.""" """Get all the rewrites applicable to an `Op`."""
if self._cached_composed_mro is None: if self._cached_composed_mro is None:
......
import gc
import operator
import pytest import pytest
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Constant, equal_computations from pytensor.graph.basic import Apply, Constant, equal_computations
from pytensor.graph.features import Feature from pytensor.graph.features import Feature
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -930,3 +934,44 @@ def test_OpToRewriterTracker(): ...@@ -930,3 +934,44 @@ def test_OpToRewriterTracker():
local_rewriter_2, local_rewriter_2,
local_rewriter_1, local_rewriter_1,
] ]
def test_rewrite_weakref_leak():
"""Check we don't have weakref leak on our rewrites"""
def _growth(limit=10, peak_stats={}):
"""Vendoring of objgraph.growth
Source: https://github.com/mgedmin/objgraph/blob/94b1ca61a11109547442701800292dcfc7f59fc8/objgraph.py#L253
"""
gc.collect()
objects = gc.get_objects()
stats = {}
for o in objects:
n = type(o).__name__
stats[n] = stats.get(n, 0) + 1
deltas = {}
for name, count in stats.items():
old_count = peak_stats.get(name, 0)
if count > old_count:
deltas[name] = count - old_count
peak_stats[name] = count
deltas = sorted(deltas.items(), key=operator.itemgetter(1), reverse=True)
if limit:
deltas = deltas[:limit]
return [(name, stats[name], delta) for name, delta in deltas]
x = vector("x")
y = exp(x)
for i in range(20):
rewrite_graph(y, clone=False)
res = _growth()
# Only start checking after warmup
if i > 15:
assert not res, "Object counts are still growing"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论