提交 2e45b019 authored 作者: Frederic Bastien's avatar Frederic Bastien

pickle function name correctly.

上级 e59aaf76
......@@ -2189,7 +2189,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
profile=None,
on_unused_input=None,
fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None):
output_keys=None,
name=None):
self.profile = profile
optimizer = mode.optimizer
# Handle the case where inputs and/or outputs is a single
......@@ -2320,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.mode = mode
self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys
self.name = name
def create(self, defaults=None, trustme=False, storage_map=None):
"""
......@@ -2406,7 +2408,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self)
self.return_none, self.output_keys, self,
name=self.name)
return fn
......
......@@ -359,7 +359,8 @@ class Function(object):
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, output_keys, maker):
defaults, unpack_single, return_none, output_keys, maker,
name=None):
self.fn = fn
self.input_storage = input_storage
self.output_storage = output_storage
......@@ -371,7 +372,7 @@ class Function(object):
self.maker = maker
self.profile = None # reassigned in FunctionMaker.create
self.trust_input = False # If True, we don't check the input parameter
self.name = None
self.name = name
self.nodes_with_inner_function = []
self.output_keys = output_keys
......@@ -1201,6 +1202,9 @@ class FunctionMaker(object):
- 'warn': log a warning
- 'ignore': do not do anything
- None: Use the value in the Theano flags on_unused_input.
name : str
An optional name for this function. If used, the profile mode will
print the time spent in this function.
"""
......@@ -1415,7 +1419,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None,
output_keys=None):
output_keys=None, name=None):
mode = theano.compile.mode.get_mode(mode)
# Assert old way of working isn't used
......@@ -1554,7 +1558,7 @@ class FunctionMaker(object):
# hacky thing so VMLinker knows about updates
self.linker.accept_var_updates(
fgraph_updated_vars(fgraph, inputs))
fgraph.name = name
self.indices = indices
self.inputs = inputs
self.expanded_inputs = inputs
......@@ -1566,6 +1570,7 @@ class FunctionMaker(object):
self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys
self.name = name
self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
......@@ -1712,7 +1717,8 @@ class FunctionMaker(object):
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single,
self.return_none, self.output_keys, self)
self.return_none, self.output_keys, self,
name=self.name)
fn.profile = self.profile
return fn
......@@ -1817,7 +1823,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys)
output_keys=output_keys,
name=name)
with theano.configparser.change_flags(compute_test_value="off"):
fn = m.create(defaults)
finally:
......@@ -1827,8 +1834,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
# TODO: append
profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
fn.name = name
fn.maker.fgraph.name = name
return fn
......
......@@ -610,6 +610,8 @@ class T_picklefunction(unittest.TestCase):
self.assertFalse(x in g.value)
self.assertTrue(len(f.defaults) == len(g.defaults))
self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs)
self.assertTrue(f.name == g.name)
self.assertTrue(f.maker.fgraph.name == f.maker.fgraph.name)
# print 'f.defaults = %s' % (f.defaults, )
# print 'g.defaults = %s' % (g.defaults, )
self.assertTrue(all([f_req == g_req and f_feed == g_feed and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论