提交 4ab08dea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix a memory leak in the C implementation of DimShuffle

上级 3d8553f7
...@@ -31,6 +31,8 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, ...@@ -31,6 +31,8 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
PyArrayObject *transposed_input = PyArrayObject *transposed_input =
(PyArrayObject *)PyArray_Transpose(_input, &permute); (PyArrayObject *)PyArray_Transpose(_input, &permute);
Py_DECREF(_input);
PyDimMem_FREE(permute.ptr); PyDimMem_FREE(permute.ptr);
npy_intp *res_shape = PyArray_DIMS(transposed_input); npy_intp *res_shape = PyArray_DIMS(transposed_input);
...@@ -68,7 +70,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, ...@@ -68,7 +70,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape, *res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
NPY_CORDER); NPY_CORDER);
/* Py_XDECREF(transposed_input); */ Py_DECREF(transposed_input);
PyDimMem_FREE(reshape_shape.ptr); PyDimMem_FREE(reshape_shape.ptr);
......
import math import math
import tracemalloc
from copy import copy from copy import copy
import numpy as np import numpy as np
...@@ -141,6 +142,41 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -141,6 +142,41 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
# Confirm the broadcasted value in the output # Confirm the broadcasted value in the output
assert np.array_equiv(outputs[0].storage[0], 2039) assert np.array_equiv(outputs[0].storage[0], 2039)
@pytest.mark.parametrize("inplace", [True, False])
def test_memory_leak(self, inplace):
import gc
n = 100_000
x = aesara.shared(np.ones(n, dtype=np.float64))
y = x.dimshuffle([0, "x"])
y.owner.op.inplace = inplace
f = aesara.function([], y, mode=Mode(optimizer=None))
assert len(f.maker.fgraph.apply_nodes) == 2
assert isinstance(f.maker.fgraph.toposort()[0].op, DimShuffle)
assert f.maker.fgraph.toposort()[0].op.inplace is inplace
tracemalloc.start()
blocks_last = None
block_diffs = []
for i in range(50):
x.set_value(np.ones(n))
_ = f()
_ = gc.collect()
blocks_i, _ = tracemalloc.get_traced_memory()
if blocks_last is not None:
blocks_diff = (blocks_i - blocks_last) // 10 ** 3
block_diffs.append(blocks_diff)
blocks_last = blocks_i
tracemalloc.stop()
assert np.allclose(np.mean(block_diffs), 0)
class TestBroadcast: class TestBroadcast:
# this is to allow other types to reuse this class to test their ops # this is to allow other types to reuse this class to test their ops
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论