提交 2e45b019 authored 作者: Frederic Bastien's avatar Frederic Bastien

pickle function name correctly.

上级 e59aaf76
...@@ -2189,7 +2189,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2189,7 +2189,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
profile=None, profile=None,
on_unused_input=None, on_unused_input=None,
fgraph=None, # If present the optimized graph. we ignore it. fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None): output_keys=None,
name=None):
self.profile = profile self.profile = profile
optimizer = mode.optimizer optimizer = mode.optimizer
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
...@@ -2320,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2320,6 +2321,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.mode = mode self.mode = mode
self.on_unused_input = on_unused_input # Used for the pickling/copy self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys self.output_keys = output_keys
self.name = name
def create(self, defaults=None, trustme=False, storage_map=None): def create(self, defaults=None, trustme=False, storage_map=None):
""" """
...@@ -2406,7 +2408,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2406,7 +2408,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
storage_map=storage_map) storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices, fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single, self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self,
name=self.name)
return fn return fn
......
...@@ -359,7 +359,8 @@ class Function(object): ...@@ -359,7 +359,8 @@ class Function(object):
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, def __init__(self, fn, input_storage, output_storage, indices, outputs,
defaults, unpack_single, return_none, output_keys, maker): defaults, unpack_single, return_none, output_keys, maker,
name=None):
self.fn = fn self.fn = fn
self.input_storage = input_storage self.input_storage = input_storage
self.output_storage = output_storage self.output_storage = output_storage
...@@ -371,7 +372,7 @@ class Function(object): ...@@ -371,7 +372,7 @@ class Function(object):
self.maker = maker self.maker = maker
self.profile = None # reassigned in FunctionMaker.create self.profile = None # reassigned in FunctionMaker.create
self.trust_input = False # If True, we don't check the input parameter self.trust_input = False # If True, we don't check the input parameter
self.name = None self.name = name
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_keys = output_keys self.output_keys = output_keys
...@@ -1201,6 +1202,9 @@ class FunctionMaker(object): ...@@ -1201,6 +1202,9 @@ class FunctionMaker(object):
- 'warn': log a warning - 'warn': log a warning
- 'ignore': do not do anything - 'ignore': do not do anything
- None: Use the value in the Theano flags on_unused_input. - None: Use the value in the Theano flags on_unused_input.
name : str
An optional name for this function. If used, the profile mode will
print the time spent in this function.
""" """
...@@ -1415,7 +1419,7 @@ class FunctionMaker(object): ...@@ -1415,7 +1419,7 @@ class FunctionMaker(object):
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode=None, accept_inplace=False, function_builder=Function, mode=None, accept_inplace=False, function_builder=Function,
profile=None, on_unused_input=None, fgraph=None, profile=None, on_unused_input=None, fgraph=None,
output_keys=None): output_keys=None, name=None):
mode = theano.compile.mode.get_mode(mode) mode = theano.compile.mode.get_mode(mode)
# Assert old way of working isn't used # Assert old way of working isn't used
...@@ -1554,7 +1558,7 @@ class FunctionMaker(object): ...@@ -1554,7 +1558,7 @@ class FunctionMaker(object):
# hacky thing so VMLinker knows about updates # hacky thing so VMLinker knows about updates
self.linker.accept_var_updates( self.linker.accept_var_updates(
fgraph_updated_vars(fgraph, inputs)) fgraph_updated_vars(fgraph, inputs))
fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
self.expanded_inputs = inputs self.expanded_inputs = inputs
...@@ -1566,6 +1570,7 @@ class FunctionMaker(object): ...@@ -1566,6 +1570,7 @@ class FunctionMaker(object):
self.function_builder = function_builder self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used for the pickling/copy self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys self.output_keys = output_keys
self.name = name
self.required = [(i.value is None) for i in self.inputs] self.required = [(i.value is None) for i in self.inputs]
self.refeed = [ self.refeed = [
...@@ -1712,7 +1717,8 @@ class FunctionMaker(object): ...@@ -1712,7 +1717,8 @@ class FunctionMaker(object):
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self,
name=self.name)
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1817,7 +1823,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1817,7 +1823,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys) output_keys=output_keys,
name=name)
with theano.configparser.change_flags(compute_test_value="off"): with theano.configparser.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
finally: finally:
...@@ -1827,8 +1834,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1827,8 +1834,6 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
# TODO: append # TODO: append
profile.nb_nodes = len(fn.maker.fgraph.apply_nodes) profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
fn.name = name
fn.maker.fgraph.name = name
return fn return fn
......
...@@ -610,6 +610,8 @@ class T_picklefunction(unittest.TestCase): ...@@ -610,6 +610,8 @@ class T_picklefunction(unittest.TestCase):
self.assertFalse(x in g.value) self.assertFalse(x in g.value)
self.assertTrue(len(f.defaults) == len(g.defaults)) self.assertTrue(len(f.defaults) == len(g.defaults))
self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs) self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs)
self.assertTrue(f.name == g.name)
self.assertTrue(f.maker.fgraph.name == f.maker.fgraph.name)
# print 'f.defaults = %s' % (f.defaults, ) # print 'f.defaults = %s' % (f.defaults, )
# print 'g.defaults = %s' % (g.defaults, ) # print 'g.defaults = %s' % (g.defaults, )
self.assertTrue(all([f_req == g_req and f_feed == g_feed and self.assertTrue(all([f_req == g_req and f_feed == g_feed and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论