提交 ce29622c authored 作者: ChienliMa's avatar ChienliMa

Add some docs to FunctionGraph and minor changes of function_module.copy()

上级 c6c3d454
...@@ -571,27 +571,13 @@ class Function(object): ...@@ -571,27 +571,13 @@ class Function(object):
# copy SymbolocKits # copy SymbolocKits
ins, outs = copy.deepcopy([self.maker.inputs, self.maker.outputs]) ins, outs = copy.deepcopy([self.maker.inputs, self.maker.outputs])
# get copied input, output variables # copy fgraph and get memo
in_vars = [ i.variable for i in ins ]
out_vars = [ o.variable for o in outs ]
# construct memo that maps old variables to new variables
memo = {}
for old_i, new_i in zip([i.variable for i in self.maker.inputs],
in_vars):
memo[old_i] = new_i
for old_o, new_o in zip([o.variable for o in self.maker.outputs],
out_vars):
memo[old_o] = new_o
# construct new fgraph with new vars and complete the memo
memo = clone_get_equiv( in_vars, out_vars, memo )
new_fgraph = FunctionGraph(in_vars, out_vars)
# re-initialize new FunctionMaker
maker = self.maker maker = self.maker
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# use copied ins, outs and fgraph to init a maker
new_maker = FunctionMaker(inputs=ins, outputs=outs, mode=maker.mode, new_maker = FunctionMaker(inputs=ins, outputs=outs, mode=maker.mode,
fgraph=new_fgraph, profile=maker.profile, fgraph=fg_cpy, profile=maker.profile,
accept_inplace=maker.accept_inplace, accept_inplace=maker.accept_inplace,
function_builder=maker.function_builder, function_builder=maker.function_builder,
on_unused_input=maker.on_unused_input) on_unused_input=maker.on_unused_input)
...@@ -602,20 +588,30 @@ class Function(object): ...@@ -602,20 +588,30 @@ class Function(object):
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
for key in storage_map.keys(): for key in storage_map.keys():
# output_storages should not be shared # output_storages should not be shared
if key not in self.maker.fgraph.outputs and \ # if key not in self.maker.fgraph.outputs and \
memo.has_key(key): # memo.has_key(key):
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
# copy input storages and link function with new storage_map # copy input storages if it's mutable
input_storage = copy.copy([getattr(i, 'value', None) for i in ins]) input_storage = []
new_func = new_maker.create(input_storage, storage_map=new_storage_map) for i in self.maker.inputs:
storage = getattr(i, 'value', None)
if isinstance(i.variable, theano.tensor.Constant) or\
not i.mutable:
input_storage.append(storage )
else:
input_storage.append( copy.deepcopy[storage])
new_func = new_maker.create(input_storage, \
storage_map=new_storage_map)
# share immutable SharedVariable's storage # share immutable SharedVariable's storage
for (input, _1, _2), here, there in zip(self.indices, # for (input, _1, _2), here, there in zip(self.indices,
self.input_storage, # self.input_storage,
new_func.input_storage): # new_func.input_storage):
if not input.mutable: # if isinstance(i.variable, theano.tensor.Constant) or \
there.data = here.data # not input.mutable:
# there.data = here.data
return new_func return new_func
......
...@@ -243,7 +243,7 @@ class T_function(unittest.TestCase): ...@@ -243,7 +243,7 @@ class T_function(unittest.TestCase):
def test_copy_share_memory(self): def test_copy_share_memory(self):
x = T.fscalar('x') x = T.fscalar('x')
y = T.tanh((x+2)/(x+0.2)**2) y = T.tanh((x+2)/(x-0.2)**2)
# test for PerformaLinker, will cover VM_linker later # test for PerformaLinker, will cover VM_linker later
ori = theano.function([x], [y], mode="FAST_COMPILE") ori = theano.function([x], [y], mode="FAST_COMPILE")
...@@ -256,13 +256,11 @@ class T_function(unittest.TestCase): ...@@ -256,13 +256,11 @@ class T_function(unittest.TestCase):
fgraph_cpy = cpy.maker.fgraph fgraph_cpy = cpy.maker.fgraph
# assert intermediate and Constants storages are shared # assert intermediate and Constants storages are shared
i_o_variables = fgraph_cpy.inputs i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
ori_storages = storage_map_ori.values() ori_storages = storage_map_ori.values()
for key in storage_map_cpy.keys(): for key in storage_map_cpy.keys():
if key not in i_o_variables or isinstance(key, theano.tensor.Constant): if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
print key
storage = storage_map_cpy[key] storage = storage_map_cpy[key]
print [ storage is s for s in ori_storages]
self.assertTrue( any([ storage is s for s in ori_storages])) self.assertTrue( any([ storage is s for s in ori_storages]))
# assert storages of SharedVariable without updates are shared # assert storages of SharedVariable without updates are shared
......
...@@ -276,8 +276,19 @@ class FunctionGraph(utils.object2): ...@@ -276,8 +276,19 @@ class FunctionGraph(utils.object2):
return True return True
return False return False
# import # ### import ###
def __import_r__(self, variable, reason): def __import_r__(self, variables, reason):
"""
Import variables into this FunctionGraph, along with their apply nodes
if those nodes are not already in this graph.
----------------------
Parameters:
variables -- Iterable of variables to import
reason -- String. The reason for the import.
----------------------
Returns:
None
"""
global NullType global NullType
if NullType is None: if NullType is None:
from .null_type import NullType from .null_type import NullType
...@@ -296,6 +307,12 @@ class FunctionGraph(utils.object2): ...@@ -296,6 +307,12 @@ class FunctionGraph(utils.object2):
self.variables.add(variable) self.variables.add(variable)
def __import__(self, apply_node, check=True, reason=None): def __import__(self, apply_node, check=True, reason=None):
"""
Given an apply_node, recursively search backward from this node to the
known graph, then add all unknown variables and apply_nodes to this graph.
"""
node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set. # in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论