提交 8464ed71 authored 作者: Frederic's avatar Frederic 提交者: ChienliMa

Fix profiling and add name to Function.copy()

上级 2ba562ae
......@@ -15,7 +15,7 @@ import warnings
import numpy
import theano
from theano import gof
from theano import config, gof
from functools import partial
from theano.compat import izip
from theano.gof import graph
......@@ -542,7 +542,8 @@ class Function(object):
"""
return self.copy()
def copy(self, share_memory=False, swap=None, delete_updates=False):
def copy(self, share_memory=False, swap=None, delete_updates=False,
name=None, profile=None):
"""
Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate
......@@ -562,6 +563,11 @@ class Function(object):
delete_updates -- { boolean } Default is False. If True, Copied
function will not have update.
name -- { string } If provided, will be the name of the new
Function. Otherwise, it will be old + " copy"
profile -- as theano.function profile parameter
---------------------
Returns:
func -- Copied theano.Function
......@@ -664,11 +670,26 @@ class Function(object):
for key in storage_map.keys():
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
if not name and self.name:
name = self.name + " copy"
input_storage = [i.value for i in ins]
# reinitialize new maker and create new function
if profile is None:
profile = config.profile
# profile -> True or False
if profile is True:
if name:
message = name
else:
message = str(maker.profile.message) + " copy"
profile = theano.compile.profiling.ProfileStats(message=message)
# profile -> object
elif type(profile) == str:
profile = theano.compile.profiling.ProfileStats(message=profile)
f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
mode=maker.mode, profile=maker.profile,
mode=maker.mode, profile=profile,
on_unused_input=maker.on_unused_input,
function_builder=maker.function_builder,
accept_inplace=maker.accept_inplace
......@@ -699,6 +720,8 @@ class Function(object):
f_cpy.finder[swap[in_ori.variable]] = container
in_cpy.variable = swap[in_ori.variable]
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy
def __call__(self, *args, **kwargs):
......
......@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if profile is True:
profile = ProfileStats(message=name)
# profile -> object
if type(profile) == str:
elif type(profile) == str:
profile = ProfileStats(message=profile)
# profile is typically either False or an object at this point.
# No need to block other objects being passed through though. It might be
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论