提交 be54333b authored 作者: Frederic Bastien's avatar Frederic Bastien

simpler test for the memory leak

上级 edf434c0
......@@ -1059,34 +1059,42 @@ class test_fusion(unittest.TestCase):
print "time", self.do(mode, shared, shp=(1000,1000),gpu=False, assert_len_topo=False,slice=s, nb_repeat=100)
def tes_memory_leak(self, mode=compile.mode.predefined_modes['FAST_RUN'], shared_fn=shared, shp=(3000,3000), gpu=False, nb_repeat=30, assert_len_topo=True, slice=None):
def tes_memory_leak(self, mode=compile.mode.Mode('c', 'merge'), shared_fn=shared, shp=(3000,3000), gpu=False, nb_repeat=30, assert_len_topo=True, slice=None):
"""
param shared_fn: if None, will use compile.function
verify that the elemwise fusion work
Test with and without DimShuffle
"""
#TODO: disable the canonizer?
fx, fy = fmatrices('xy')
fx = fmatrices('x')
fxv = numpy.zeros(shp, dtype='float32')+ 2
fyv = numpy.zeros(shp, dtype='float32')+ 3
cases = [
(fx+fy,(fx,fy),(fxv,fyv),1,fxv+fyv,'float32'),#1
(fx,(fx),(fxv),'float32'),#1
]
import gc, pdb, objgraph, weakref
d={}
dl=[]
v1=None
for id, [g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype] in enumerate(cases):
mode=compile.mode.Mode('c', 'merge')
for id, [g, sym_inputs, val_inputs, out_dtype] in enumerate(cases):
for zzzz in range(nb_repeat):
v=numpy.zeros(shp, dtype=out_dtype)
gc.collect();gc.collect();gc.collect()
print 'v1',v1
v1=weakref.ref(v)
# print 'v1',v1
# v1=weakref.ref(v)
out=shared_fn(v,'out')
f = pfunc(sym_inputs,[],updates=[(out,out+g)],mode=mode)
pdb.set_trace()
# f = pfunc(sym_inputs,[],updates=[(out,out+g)],mode=mode)
# f = pfunc([fx],[],updates=[(out,out+fx)],mode=mode)
# f = pfunc([fx],out+fx,mode=mode)
# f = compile.function([fx,out],[out+fx],mode=mode)#no memory leak.
f = compile.function([fx,compile.In(variable=out, value=out.container, mutable=None)],
[out+fx],mode=mode)#if mutable is True or False, their is a memory leak
del v
gc.collect();gc.collect();gc.collect()
pdb.set_trace()
if True:
if False:
gc.collect();gc.collect();gc.collect()
nd=objgraph.typestats()
print 'key, old val, new val, diff'
......@@ -1097,7 +1105,7 @@ class test_fusion(unittest.TestCase):
d=nd
# pdb.set_trace()
if True:
if False:
gc.collect();gc.collect();gc.collect()
ndl=objgraph.by_type('list')
ll=[]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论