提交 dc36c99b authored 作者: ChienliMa's avatar ChienliMa

Finish draft, it work. Start working on test and improve the code

上级 81f7134f
......@@ -19,7 +19,6 @@ from theano import gof
from functools import partial
from theano.compat import izip
import theano.compile.mode
from theano.compile import SharedVariable
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op
......@@ -555,11 +554,6 @@ class Function(object):
storages( see method __copy__() ) and same maker. If two functions
share memory and allow_gc=False, this will increase executing speed
and save memory.
swap -- { dict } Default is None. If not None, it should be a
dictionay {String: theano.SharedVariable} that map original
SharedVariable's name(string) to new SharedVariable.
Two SharedVariable should have the same theano type.
---------------------
Returns:
func -- Copied theano.Function
......@@ -571,23 +565,23 @@ class Function(object):
# copy fgraph and get memo
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# if swap is not None, swap SharedVariable
# swap SharedVariable if need
if swap != None:
assert type(swap) is dict, ("Parameter swap should be a dict."
"Given:", type(swap))
swapSV(swap, ins, fg_cpy, memo)
self._swapSV(swap, ins, fg_cpy, memo)
#used to prevent
swapped_sv = swap.keys()
storage_map = self.fn.storage_map
new_storage_map = {}
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
if share_memory:
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
for key in storage_map.keys():
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
......@@ -624,7 +618,7 @@ class Function(object):
return f_cpy
def _swapSV(swap, ins, fg_cpy, memo):
def _swapSV(self, swap, ins, fg_cpy, memo):
"""
Hidden auxiliary function, used to swap SharedVariable in maker.inputs
and maker.fgraph.
......@@ -652,6 +646,8 @@ class Function(object):
# Main code start here
exist_name = [ i.variable.name for i in ins ]
for name, sv in swap.iteritems():
assert type(name) is str, ("The SharedVariable need to be replaced"
"should be indicated by its name(string). Given:",type(name))
# check if SharedVariable exist
if name not in exist_name:
print ("WARNING: SharedVariable named: %s not found") % name
......@@ -661,7 +657,7 @@ class Function(object):
sv = sv.clone()
# replace SharedVariable in maker's In instances
for i in maker.inputs:
for i in ins:
if i.variable.name == name:
checkSV( i.variable, sv)
i.variable, i.value = sv, sv.container
......@@ -669,14 +665,19 @@ class Function(object):
# replace SharedVariable in fgraph and memo
for var_ori, var_cpy in memo.iteritems():
if isinstance(ori_in, SharedVariable) and ori_in.name == name:
if isinstance(var_ori, theano.compile.SharedVariable) and var_ori.name == name:
checkSV(var_cpy, sv)
# replace variable in fgraph
fgraph_cpy.replace(var_cpy, sv)
# modify fg_cpy.iputs
for i in xrange(len(fg_cpy.inputs)):
if fg_cpy.inputs[i].name == name:
fg_cpy.inputs[i] = sv
break
fg_cpy.replace(var_cpy, sv)
# modify memo so that ori_var->this shared var
memo[var_ori] = sv
break
def __call__(self, *args, **kwargs):
profile = self.profile
t0 = time.time()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论