提交 8464ed71 authored 作者: Frederic's avatar Frederic 提交者: ChienliMa

Fix profiling and add name to Function.copy()

上级 2ba562ae
...@@ -15,7 +15,7 @@ import warnings ...@@ -15,7 +15,7 @@ import warnings
import numpy import numpy
import theano import theano
from theano import gof from theano import config, gof
from functools import partial from functools import partial
from theano.compat import izip from theano.compat import izip
from theano.gof import graph from theano.gof import graph
...@@ -542,7 +542,8 @@ class Function(object): ...@@ -542,7 +542,8 @@ class Function(object):
""" """
return self.copy() return self.copy()
def copy(self, share_memory=False, swap=None, delete_updates=False): def copy(self, share_memory=False, swap=None, delete_updates=False,
name=None, profile=None):
""" """
Copy this function. Copied function will have separated maker and Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate fgraph with original function. User can choose whether to separate
...@@ -562,6 +563,11 @@ class Function(object): ...@@ -562,6 +563,11 @@ class Function(object):
delete_updates -- { boolean } Default is False. If True, Copied delete_updates -- { boolean } Default is False. If True, Copied
function will not have update. function will not have update.
name -- { string } If provided, will be the name of the new
Function. Otherwise, it will be old + " copy"
profile -- as theano.function profile parameter
--------------------- ---------------------
Returns: Returns:
func -- Copied theano.Function func -- Copied theano.Function
...@@ -664,11 +670,26 @@ class Function(object): ...@@ -664,11 +670,26 @@ class Function(object):
for key in storage_map.keys(): for key in storage_map.keys():
if key not in i_o_vars: if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key] new_storage_map[memo[key]] = storage_map[key]
if not name and self.name:
name = self.name + " copy"
input_storage = [i.value for i in ins] input_storage = [i.value for i in ins]
# reinitialize new maker and create new function # reinitialize new maker and create new function
if profile is None:
profile = config.profile
# profile -> True or False
if profile is True:
if name:
message = name
else:
message = str(maker.profile.message) + " copy"
profile = theano.compile.profiling.ProfileStats(message=message)
# profile -> object
elif type(profile) == str:
profile = theano.compile.profiling.ProfileStats(message=profile)
f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy, f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
mode=maker.mode, profile=maker.profile, mode=maker.mode, profile=profile,
on_unused_input=maker.on_unused_input, on_unused_input=maker.on_unused_input,
function_builder=maker.function_builder, function_builder=maker.function_builder,
accept_inplace=maker.accept_inplace accept_inplace=maker.accept_inplace
...@@ -699,6 +720,8 @@ class Function(object): ...@@ -699,6 +720,8 @@ class Function(object):
f_cpy.finder[swap[in_ori.variable]] = container f_cpy.finder[swap[in_ori.variable]] = container
in_cpy.variable = swap[in_ori.variable] in_cpy.variable = swap[in_ori.variable]
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy return f_cpy
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
......
...@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if profile is True: if profile is True:
profile = ProfileStats(message=name) profile = ProfileStats(message=name)
# profile -> object # profile -> object
if type(profile) == str: elif type(profile) == str:
profile = ProfileStats(message=profile) profile = ProfileStats(message=profile)
# profile is typically either False or an object at this point. # profile is typically either False or an object at this point.
# No need to block other objects being passed through though. It might be # No need to block other objects being passed through though. It might be
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论