提交 ce152cb6 authored 作者: Frederic's avatar Frederic

add thunk.inputs and thunk.output automatically at many places.

上级 cf9903f8
...@@ -1680,6 +1680,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1680,6 +1680,8 @@ class _Linker(gof.link.LocalLinker):
storage_map, storage_map,
compute_map, compute_map,
no_recycling) no_recycling)
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
# Right now there is no op that when called check if # Right now there is no op that when called check if
# its ouputs are computed and don't recompute itself. # its ouputs are computed and don't recompute itself.
......
...@@ -1498,6 +1498,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1498,6 +1498,9 @@ class OpWiseCLinker(link.LocalLinker):
storage_map, storage_map,
compute_map, compute_map,
no_recycling)] no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs]
finally: finally:
node.op._op_use_c_code = old_value node.op._op_use_c_code = old_value
......
...@@ -526,6 +526,8 @@ class PerformLinker(LocalLinker): ...@@ -526,6 +526,8 @@ class PerformLinker(LocalLinker):
storage_map, storage_map,
compute_map, compute_map,
no_recycling)] no_recycling)]
thunks[-1].inputs = [storage_map[v] for v in node.inputs]
thunks[-1].outputs = [storage_map[v] for v in node.outputs]
finally: finally:
node.op._op_use_c_code = old_value node.op._op_use_c_code = old_value
......
...@@ -431,6 +431,8 @@ class PureOp(object): ...@@ -431,6 +431,8 @@ class PureOp(object):
# compute output value once with test inputs to validate graph # compute output value once with test inputs to validate graph
thunk = node.op.make_thunk(node, storage_map, compute_map, thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[]) no_recycling=[])
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk() required = thunk()
assert not required # We provided all inputs assert not required # We provided all inputs
......
...@@ -843,6 +843,9 @@ class VM_Linker(link.LocalLinker): ...@@ -843,6 +843,9 @@ class VM_Linker(link.LocalLinker):
compute_map, compute_map,
no_recycling) no_recycling)
for node in order] for node in order]
for node, thunk in zip(order, thunks):
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
if self.allow_gc: if self.allow_gc:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论