提交 81f7134f authored 作者: ChienliMa's avatar ChienliMa

draft of swaping sharedvariable

上级 127d36c3
......@@ -19,6 +19,7 @@ from theano import gof
from functools import partial
from theano.compat import izip
import theano.compile.mode
from theano.compile import SharedVariable
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op
......@@ -541,7 +542,7 @@ class Function(object):
"""
return self.copy()
def copy(self, share_memory=False):
def copy(self, share_memory=False, swap=None):
"""
Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate
......@@ -554,6 +555,11 @@ class Function(object):
storages( see method __copy__() ) and same maker. If two functions
share memory and allow_gc=False, this will increase executing speed
and save memory.
swap -- { dict } Default is None. If not None, it should be a
dictionay {String: theano.SharedVariable} that map original
SharedVariable's name(string) to new SharedVariable.
Two SharedVariable should have the same theano type.
---------------------
Returns:
func -- Copied theano.Function
......@@ -565,26 +571,31 @@ class Function(object):
# copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# if swap is not None, swap SharedVariable
if swap != None:
assert type(swap) is dict, ("Parameter swap should be a dict."
"Given:", type(swap))
swapSV(swap, ins, fg_cpy, memo)
storage_map = self.fn.storage_map
new_storage_map = {}
# If share_memory, Construct new storage_map that map new variable
# to old storage, so that the ensuing function shares storage with
# the original one.
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
if share_memory:
i_o_vars = maker.fgraph.outputs + maker.fgraph.inputs
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
for key in storage_map.keys():
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
input_storage = []
assert len(ins) == len(fg_cpy.inputs)
for in_ori, in_cpy in zip(maker.inputs, ins):
for in_ori, in_cpy, in_v in zip(maker.inputs, ins, fg_cpy.inputs):
# Since we reuse original Out instances, copied In instances
# should use the original variabls as their variables and
# updates. Otherwise the compilation will fail at function
......@@ -613,6 +624,59 @@ class Function(object):
return f_cpy
def _swapSV(swap, ins, fg_cpy, memo):
"""
Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph.
--------------------
Params:
swap -- { dict } The same as docs in copy()
ins -- { list } List of In instances to be modified.
fg_cpy -- { FounctionGraph } Copied FunctionGraph.
memo -- { dict } Dict that map old variables to new variables.
--------------------
Returns:
None
"""
def checkSV( sv_ori, sv_rpl ):
"""
Assert two SharedVariable follow some restirctions:
1. same type
2. same shape????
"""
assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with origianl one",
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type )
# Main code start here
exist_name = [ i.variable.name for i in ins ]
for name, sv in swap.iteritems():
# check if SharedVariable exist
if name not in exist_name:
print ("WARNING: SharedVariable named: %s not found") % name
continue
# we don't use the originally defined variable
sv = sv.clone()
# replace SharedVariable in maker's In instances
for i in maker.inputs:
if i.variable.name == name:
checkSV( i.variable, sv)
i.variable, i.value = sv, sv.container
break
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(ori_in, SharedVariable) and ori_in.name == name:
checkSV(var_cpy, sv)
# replace variable in fgraph
fgraph_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var
memo[var_ori] = sv
break
def __call__(self, *args, **kwargs):
profile = self.profile
t0 = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论