提交 d0cfadcb authored 作者: James Bergstra's avatar James Bergstra

added allow_gc option to PerformLinker, which frees intermediate results during…

Added an allow_gc option to PerformLinker, which frees intermediate results during execution when they are no longer needed.
上级 d18ed5ad
...@@ -217,6 +217,13 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -217,6 +217,13 @@ def map_storage(env, order, input_storage, output_storage):
return input_storage, output_storage, storage_map return input_storage, output_storage, storage_map
def clear_storage_thunk(stg):
    """Build a thunk that empties a storage cell when called.

    Calling the returned thunk assigns ``None`` to ``stg[0]``, dropping the
    cell's reference to whatever value it held so that the memory can be
    freed by the garbage collector.

    The thunk carries the same ``inputs`` / ``outputs`` attributes that are
    set on the other thunks in this file, so it can be placed in the same
    list of thunks.

    :param stg: a mutable sequence whose element 0 holds the value to drop
        (presumably a one-element storage cell taken from ``storage_map``
        -- confirm against callers).
    :returns: a zero-argument callable with ``inputs == [stg]`` and
        ``outputs == []``.
    """
    def thunk():
        # Overwrite the slot in place rather than rebinding ``stg``: every
        # other holder of this cell must see the cleared value too.
        stg[0] = None

    # A clearing thunk produces nothing; it only consumes the cell it clears.
    thunk.outputs = []
    thunk.inputs = [stg]
    return thunk
def streamline(env, thunks, order, no_recycling = [], profiler = None): def streamline(env, thunks, order, no_recycling = [], profiler = None):
"""WRITEME""" """WRITEME"""
...@@ -270,8 +277,10 @@ class PerformLinker(LocalLinker): ...@@ -270,8 +277,10 @@ class PerformLinker(LocalLinker):
the L{Env} in the order given by L{Env.toposort}. the L{Env} in the order given by L{Env.toposort}.
""" """
def __init__(self): def __init__(self, allow_gc=False):
#TODO: set allow_gc = True by default, when it works with the c&py linker
self.env = None self.env = None
self.allow_gc = allow_gc
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
""" """
...@@ -302,7 +311,20 @@ class PerformLinker(LocalLinker): ...@@ -302,7 +311,20 @@ class PerformLinker(LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
thunks = [] thunks = []
new_order = []
input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(env, order, input_storage, output_storage)
#for freeing memory
if self.allow_gc:
last_user = {}
computed = set()
for node in order:
for idx, input in enumerate(node.inputs):
last_user[input] = (node, idx)
for output in node.outputs:
computed.add(output)
for node in order: for node in order:
node_input_storage = tuple(storage_map[input] for input in node.inputs) node_input_storage = tuple(storage_map[input] for input in node.inputs)
node_output_storage = tuple(storage_map[output] for output in node.outputs) node_output_storage = tuple(storage_map[output] for output in node.outputs)
...@@ -315,6 +337,20 @@ class PerformLinker(LocalLinker): ...@@ -315,6 +337,20 @@ class PerformLinker(LocalLinker):
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunk.perform = p thunk.perform = p
thunks.append(thunk) thunks.append(thunk)
new_order.append(node)
if self.allow_gc:
for idx, input in enumerate(node.inputs):
if input not in computed:
continue
if input in env.outputs:
continue
if (node, idx) == last_user[input]:
#print '... zeroing', id(storage_map[input])
thunks.append(clear_storage_thunk(storage_map[input]))
new_order.append(node)
if no_recycling is True: if no_recycling is True:
#True is like some special code for *everything*. #True is like some special code for *everything*.
...@@ -325,11 +361,11 @@ class PerformLinker(LocalLinker): ...@@ -325,11 +361,11 @@ class PerformLinker(LocalLinker):
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
# The function that actually runs your program is one of the f's in streamline. # The function that actually runs your program is one of the f's in streamline.
f = streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = streamline(env, thunks, new_order, no_recycling = no_recycling, profiler = profiler)
return f, [Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order thunks, new_order
......
...@@ -133,7 +133,7 @@ class TestWrapLinker: ...@@ -133,7 +133,7 @@ class TestWrapLinker:
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk() fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
...@@ -148,7 +148,7 @@ class TestWrapLinker: ...@@ -148,7 +148,7 @@ class TestWrapLinker:
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker()], wrap).make_thunk() fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
......
...@@ -2220,7 +2220,7 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0 ...@@ -2220,7 +2220,7 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
t_r = as_tensor(random_projection) t_r = as_tensor(random_projection)
#random projection of o onto t_r #random projection of o onto t_r
cost = sum(t_r * o_output) cost = sum(t_r * o_output) #This sum() is defined above, it's not the builtin sum.
cost_fn = function(tensor_pt, cost) cost_fn = function(tensor_pt, cost)
num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps) num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论