提交 f8493348 authored 作者: James Bergstra's avatar James Bergstra

Fixed infer_reuse_pattern to not mark graph inputs for recycling.

上级 0c1c4710
...@@ -57,9 +57,11 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -57,9 +57,11 @@ def infer_reuse_pattern(env, outputs_to_disown):
This list (or set) is also refered to as no_recycling sometimes, especially by linker code. This list (or set) is also refered to as no_recycling sometimes, especially by linker code.
""" """
rval1 = set() rval = set()
for o in outputs_to_disown: for o in outputs_to_disown:
view_tree_set(view_map_root(o), rval1) view_tree_set(view_map_root(o), rval)
# remove from rval all of the inputs, constants, values.
rval = set(r for r in rval if r.owner is not None)
if 1: if 1:
# DEBUG STUFF # DEBUG STUFF
...@@ -67,11 +69,10 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -67,11 +69,10 @@ def infer_reuse_pattern(env, outputs_to_disown):
rval0 = _old_infer_reuse_pattern(env, outputs_to_disown) rval0 = _old_infer_reuse_pattern(env, outputs_to_disown)
rval0_set = set(rval0) rval0_set = set(rval0)
for blah in rval0_set: for r in rval0_set:
print blah assert r in rval
assert blah in rval1
return rval1 return rval
def _old_infer_reuse_pattern(env, outputs_to_disown): def _old_infer_reuse_pattern(env, outputs_to_disown):
""" """
...@@ -556,6 +557,7 @@ class SanityCheckFunction(Function): ...@@ -556,6 +557,7 @@ class SanityCheckFunction(Function):
super(SanityCheckFunction, self).__init__(*args, **kwargs) super(SanityCheckFunction, self).__init__(*args, **kwargs)
self.others = others self.others = others
self.check_equal = check_equal self.check_equal = check_equal
# DEPRECATED? Is this just for DualLinker?
def __setitem__(self, item, value): def __setitem__(self, item, value):
super(SanityCheckFunction, self).__setitem__(item, value) super(SanityCheckFunction, self).__setitem__(item, value)
...@@ -780,6 +782,7 @@ class FunctionMaker(object): ...@@ -780,6 +782,7 @@ class FunctionMaker(object):
input_storage_lists.append([input_storage_i]) input_storage_lists.append([input_storage_i])
defaults.append((self.required[i], self.refeed[i], input_storage_i)) defaults.append((self.required[i], self.refeed[i], input_storage_i))
# Get a function instance # Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage_lists) _fn, _i, _o = self.linker.make_thunk(input_storage = input_storage_lists)
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self.return_none, self) fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self.return_none, self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论