提交 599d6376 authored 作者: ChienliMa's avatar ChienliMa

Add assertion in map_storage to assert storage given by input/output_storage and…

Add assertion in map_storage to assert storage given by input/output_storage and storage_map is the same.
上级 cbdb4315
......@@ -560,9 +560,9 @@ class Function(object):
share_memory -- { boolean } Default is False. When True, two
function share intermediate storages(storages except input and
output storages). Otherwise two function will only share partial
storages( see method __copy__() ). If two functions share memory
and allow_gc=False, this will increase executing speed and save
memory.
storages( see method __copy__() ) and same maker. If two functions
share memory and allow_gc=False, this will increase executing speed
and save memory.
---------------------
Returns:
func -- Copied theano.Function
......@@ -611,6 +611,13 @@ class Function(object):
storage = getattr(in_cpy, 'value', None)
input_storage.append(storage)
# pop out input_storage in storage_map and only use storage of In
# instances to initialize the make, so that we can avoid
# storage conflictions in link.map_storage()
assert len(fg_cpy.inputs) == len(input_storage)
for in_var in fg_cpy.inputs:
new_storage_map.pop(in_var)
# reinitialize new maker and create new function
return maker.__class__(inputs=ins, outputs=maker.outputs,
fgraph=fg_cpy,
......
......@@ -539,18 +539,24 @@ def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
"""
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
if storage_map is None:
storage_map = {}
# input_storage is a list of data-containers for the inputs.
if input_storage is None:
input_storage = [[None] for input in fgraph.inputs]
else:
assert len(fgraph.inputs) == len(input_storage)
if storage_map is None:
storage_map = {}
# add input storage into storage_map
for r, storage in zip(fgraph.inputs, input_storage):
storage_map[r] = storage
if r in storage_map:
assert storage_map[r] is storage, (
"Given input_storage conflicts with storage in given"
"storage_map. Given input_storage: ", storage,
"Storage in storage_map: ", storage_map[r])
else:
storage_map[r] = storage
# for orphan in fgraph.orphans:
# if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
......@@ -559,8 +565,14 @@ def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
# allocate output storage
if output_storage is not None:
assert len(fgraph.outputs) == len(output_storage)
for r, storage in izip(fgraph.outputs, output_storage):
storage_map[r] = storage
for r, storage in zip(fgraph.outputs, output_storage):
if r in storage_map:
assert storage_map[r] is storage, (
"Given output_storage conflicts with storage in given"
"storage_map. Given output_storage: ", storage,
"Storage in storage_map: ", storage_map[r])
else:
storage_map[r] = storage
# allocate storage for intermediate computation
for node in order:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论