提交 dc36c99b authored 作者: ChienliMa's avatar ChienliMa

Finish draft, it work. Start working on test and improve the code

上级 81f7134f
...@@ -19,7 +19,6 @@ from theano import gof ...@@ -19,7 +19,6 @@ from theano import gof
from functools import partial from functools import partial
from theano.compat import izip from theano.compat import izip
import theano.compile.mode import theano.compile.mode
from theano.compile import SharedVariable
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput) In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
...@@ -555,11 +554,6 @@ class Function(object): ...@@ -555,11 +554,6 @@ class Function(object):
storages( see method __copy__() ) and same maker. If two functions storages( see method __copy__() ) and same maker. If two functions
share memory and allow_gc=False, this will increase executing speed share memory and allow_gc=False, this will increase executing speed
and save memory. and save memory.
swap -- { dict } Default is None. If not None, it should be a
dictionay {String: theano.SharedVariable} that map original
SharedVariable's name(string) to new SharedVariable.
Two SharedVariable should have the same theano type.
--------------------- ---------------------
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
...@@ -571,23 +565,23 @@ class Function(object): ...@@ -571,23 +565,23 @@ class Function(object):
# copy fgraph and get memo # copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# if swap is not None, swap SharedVariable # swap SharedVariable if need
if swap != None: if swap != None:
assert type(swap) is dict, ("Parameter swap should be a dict." self._swapSV(swap, ins, fg_cpy, memo)
"Given:", type(swap)) #used to prevent
swapSV(swap, ins, fg_cpy, memo) swapped_sv = swap.keys()
storage_map = self.fn.storage_map storage_map = self.fn.storage_map
new_storage_map = {} new_storage_map = {}
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
if share_memory: if share_memory:
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
for key in storage_map.keys(): for key in storage_map.keys():
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
...@@ -624,7 +618,7 @@ class Function(object): ...@@ -624,7 +618,7 @@ class Function(object):
return f_cpy return f_cpy
def _swapSV(swap, ins, fg_cpy, memo): def _swapSV(self, swap, ins, fg_cpy, memo):
""" """
Hidden auxiliary function, used to swap SharedVariable in maker.inputs Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph. and maker.fgraph.
...@@ -652,6 +646,8 @@ class Function(object): ...@@ -652,6 +646,8 @@ class Function(object):
# Main code start here # Main code start here
exist_name = [ i.variable.name for i in ins ] exist_name = [ i.variable.name for i in ins ]
for name, sv in swap.iteritems(): for name, sv in swap.iteritems():
assert type(name) is str, ("The SharedVariable need to be replaced"
"should be indicated by its name(string). Given:",type(name))
# check if SharedVariable exist # check if SharedVariable exist
if name not in exist_name: if name not in exist_name:
print ("WARNING: SharedVariable named: %s not found") % name print ("WARNING: SharedVariable named: %s not found") % name
...@@ -661,7 +657,7 @@ class Function(object): ...@@ -661,7 +657,7 @@ class Function(object):
sv = sv.clone() sv = sv.clone()
# replace SharedVariable in maker's In instances # replace SharedVariable in maker's In instances
for i in maker.inputs: for i in ins:
if i.variable.name == name: if i.variable.name == name:
checkSV( i.variable, sv) checkSV( i.variable, sv)
i.variable, i.value = sv, sv.container i.variable, i.value = sv, sv.container
...@@ -669,14 +665,19 @@ class Function(object): ...@@ -669,14 +665,19 @@ class Function(object):
# replace SharedVariable in fgraph and memo # replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems(): for var_ori, var_cpy in memo.iteritems():
if isinstance(ori_in, SharedVariable) and ori_in.name == name: if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv) checkSV(var_cpy, sv)
# replace variable in fgraph # replace variable in fgraph
fgraph_cpy.replace(var_cpy, sv) # modify fg_cpy.iputs
for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name:
fg_cpy.inputs[i] = sv
break
fg_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var # modify memo so that ori_var->this shared var
memo[var_ori] = sv memo[var_ori] = sv
break break
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
t0 = time.time() t0 = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论