提交 81f7134f authored 作者: ChienliMa's avatar ChienliMa

Draft of swapping SharedVariable

上级 127d36c3
...@@ -19,6 +19,7 @@ from theano import gof ...@@ -19,6 +19,7 @@ from theano import gof
from functools import partial from functools import partial
from theano.compat import izip from theano.compat import izip
import theano.compile.mode import theano.compile.mode
from theano.compile import SharedVariable
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput) In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
...@@ -541,7 +542,7 @@ class Function(object): ...@@ -541,7 +542,7 @@ class Function(object):
""" """
return self.copy() return self.copy()
def copy(self, share_memory=False): def copy(self, share_memory=False, swap=None):
""" """
Copy this function. Copied function will have separated maker and Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate fgraph with original function. User can choose whether to separate
...@@ -554,6 +555,11 @@ class Function(object): ...@@ -554,6 +555,11 @@ class Function(object):
storages( see method __copy__() ) and same maker. If two functions storages( see method __copy__() ) and same maker. If two functions
share memory and allow_gc=False, this will increase executing speed share memory and allow_gc=False, this will increase executing speed
and save memory. and save memory.
swap -- { dict } Default is None. If not None, it should be a
dictionay {String: theano.SharedVariable} that map original
SharedVariable's name(string) to new SharedVariable.
Two SharedVariable should have the same theano type.
--------------------- ---------------------
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
...@@ -565,26 +571,31 @@ class Function(object): ...@@ -565,26 +571,31 @@ class Function(object):
# copy fgraph and get memo # copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# if swap is not None, swap SharedVariable
if swap != None:
assert type(swap) is dict, ("Parameter swap should be a dict."
"Given:", type(swap))
swapSV(swap, ins, fg_cpy, memo)
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
new_storage_map = {} new_storage_map = {}
# If share_memory, Construct new storage_map that map new variable if share_memory:
# to old storage, so that the ensuing function shares storage with # Construct new storage_map that map new variable to old storage,
# the original one. # so that the ensuing function shares storage with the original one
# TODO: We could share the output storage, but we must make sure # TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This # 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to # is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too. # use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure # But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map. # it is well tested, we don't share the part of the storage_map.
if share_memory:
i_o_vars = maker.fgraph.outputs + maker.fgraph.inputs
for key in storage_map.keys(): for key in storage_map.keys():
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
input_storage = [] input_storage = []
assert len(ins) == len(fg_cpy.inputs) assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy in zip(maker.inputs, ins):
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# Since we reuse original Out instances, copied In instances # Since we reuse original Out instances, copied In instances
# should use the original variabls as their variables and # should use the original variabls as their variables and
# updates. Otherwise the compilation will fail at function # updates. Otherwise the compilation will fail at function
...@@ -613,6 +624,59 @@ class Function(object): ...@@ -613,6 +624,59 @@ class Function(object):
return f_cpy return f_cpy
def _swapSV(swap, ins, fg_cpy, memo):
    """
    Hidden auxiliary function that swaps SharedVariables, by name, in a
    list of In instances and in a copied FunctionGraph.

    For every (name, SharedVariable) pair in `swap` whose name matches the
    variable of one of `ins`, a clone of the given SharedVariable replaces
    the matching In's variable/value, and replaces the matching variable in
    `fg_cpy` and `memo`. Names that match no In are reported with a
    printed warning and skipped.
    --------------------
    Params:
        swap -- { dict } Maps the name (string) of an original
            SharedVariable to the SharedVariable that should replace it.
        ins -- { list } List of In instances to be modified in place.
        fg_cpy -- { FunctionGraph } Copied FunctionGraph, modified in place
            through its replace() method.
        memo -- { dict } Dict that maps old variables to new variables,
            updated in place so the old variable maps to the replacement.
    --------------------
    Returns:
        None
    --------------------
    Raises:
        AssertionError if the replacement's type differs from the type of
        the variable it replaces.
    """
    def check_sv(sv_ori, sv_rpl):
        """
        Assert that the replacement SharedVariable has the same type as
        the variable it replaces.
        """
        assert sv_ori.type == sv_rpl.type, (
            "Type of given SharedVariable conflicts with origianl one",
            "Type of given SharedVariable:", sv_rpl.type,
            "Type of original SharedVariable:", sv_ori.type)

    # Names of the variables currently bound to the In instances.
    exist_name = [i.variable.name for i in ins]

    # items() (rather than iteritems()) works on both Python 2 and 3.
    for name, sv in swap.items():
        # Skip (with a warning) names that match none of the inputs.
        if name not in exist_name:
            # Format before printing so this also works when print is a
            # function (Python 3), where `print(...) % name` would fail.
            print("WARNING: SharedVariable named: %s not found" % name)
            continue

        # Work on a clone; the variable object passed in by the caller is
        # not used directly.
        sv = sv.clone()

        # Replace the SharedVariable in the first matching In instance.
        for i in ins:
            if i.variable.name == name:
                check_sv(i.variable, sv)
                i.variable, i.value = sv, sv.container
                break

        # Replace the SharedVariable in the copied fgraph and in memo.
        # Iterate over a snapshot because memo is updated inside the loop.
        for var_ori, var_cpy in list(memo.items()):
            if isinstance(var_ori, SharedVariable) and var_ori.name == name:
                check_sv(var_cpy, sv)
                # Replace the variable in the copied fgraph.
                fg_cpy.replace(var_cpy, sv)
                # Make the original variable map to the replacement.
                memo[var_ori] = sv
                break
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
t0 = time.time() t0 = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论