提交 af1e743d authored 作者: ChienliMa's avatar ChienliMa

format change

上级 85beeaca
...@@ -22,9 +22,8 @@ import theano.compile.mode ...@@ -22,9 +22,8 @@ import theano.compile.mode
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput) In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
from theano.gof.graph import is_same_graph, clone_get_equiv from theano.gof.graph import is_same_graph
from theano.gof.op import ops_with_inner_function from theano.gof.op import ops_with_inner_function
from theano.gof.fg import FunctionGraph
import logging import logging
_logger = logging.getLogger('theano.compile.function_module') _logger = logging.getLogger('theano.compile.function_module')
...@@ -553,13 +552,14 @@ class Function(object): ...@@ -553,13 +552,14 @@ class Function(object):
def copy(self, share_memory=False): def copy(self, share_memory=False):
""" """
Copy this function. Copied function will have separated maker and fgraph Copy this function. Copied function will have separated maker and
with original function. User can choose whether to separate storage by fgraph with original function. User can choose whether to separate
changing the share_memory arguments storage by changing the share_memory arguments
--------------------- ---------------------
Params: Params:
share_memory -- { boolean } Default is False. When True, two function share_memory -- { boolean } Default is False. When True, two
share intermediate storages(storages except input and output storages) function share intermediate storages(storages except input and
output storages)
--------------------- ---------------------
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
...@@ -575,13 +575,6 @@ class Function(object): ...@@ -575,13 +575,6 @@ class Function(object):
maker = self.maker maker = self.maker
fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False) fg_cpy, memo = maker.fgraph.clone_get_equiv(attach_feature=False)
# use copied ins, outs and fgraph to init a maker
new_maker = maker.__class__(inputs=ins, outputs=outs, mode=maker.mode,
fgraph=fg_cpy, profile=maker.profile,
accept_inplace=maker.accept_inplace,
function_builder=maker.function_builder,
on_unused_input=maker.on_unused_input)
# construct new storage_map that map new variable to old storage # construct new storage_map that map new variable to old storage
# so that the ensuing function shares storage with the original one # so that the ensuing function shares storage with the original one
new_storage_map = {} new_storage_map = {}
...@@ -594,22 +587,26 @@ class Function(object): ...@@ -594,22 +587,26 @@ class Function(object):
# But to be safe for now as it isn't documented and we aren't sure # But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map. # it is well tested, we don't share the part of the storage_map.
for key in storage_map.keys(): for key in storage_map.keys():
if key not in self.maker.fgraph.outputs: if key not in maker.fgraph.outputs:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
# copy input storages if it's mutable # copy input storages if it's mutable
input_storage = [] input_storage = []
for i in self.maker.inputs: for i in maker.inputs:
storage = getattr(i, 'value', None) storage = getattr(i, 'value', None)
if isinstance(i.variable, theano.tensor.Constant) or\ if isinstance(i.variable, theano.tensor.Constant) or\
not i.mutable: not i.mutable:
input_storage.append(storage ) input_storage.append(storage)
else: else:
input_storage.append( copy.deepcopy[storage]) input_storage.append(copy.deepcopy[storage])
new_func = new_maker.create(input_storage, storage_map=new_storage_map) # reinitialize new maker and create new function
return maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
return new_func mode=maker.mode, profile=maker.profile,
on_unused_input=maker.on_unused_input,
function_builder=maker.function_builder,
accept_inplace=maker.accept_inplace).create(
input_storage, storage_map=new_storage_map)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论