提交 fd05e3a3 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6419 from nouiz/debug_mode

Fix tests error in DebugMode. Make DebugMode more pickable
...@@ -1388,7 +1388,7 @@ class _FunctionGraphEvent(object): ...@@ -1388,7 +1388,7 @@ class _FunctionGraphEvent(object):
self.node = node self.node = node
self.op = node.op self.op = node.op
self.idx = idx self.idx = idx
self.reason = reason self.reason = str(reason)
def __str__(self): def __str__(self):
if self.kind == 'change': if self.kind == 'change':
...@@ -1472,7 +1472,7 @@ class _VariableEquivalenceTracker(object): ...@@ -1472,7 +1472,7 @@ class _VariableEquivalenceTracker(object):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node, self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason)) reason=str(reason)))
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
self.active_nodes.remove(node) self.active_nodes.remove(node)
...@@ -1480,7 +1480,7 @@ class _VariableEquivalenceTracker(object): ...@@ -1480,7 +1480,7 @@ class _VariableEquivalenceTracker(object):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('import', node, self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason)) reason=str(reason)))
assert node not in self.active_nodes assert node not in self.active_nodes
self.active_nodes.add(node) self.active_nodes.add(node)
...@@ -1501,8 +1501,9 @@ class _VariableEquivalenceTracker(object): ...@@ -1501,8 +1501,9 @@ class _VariableEquivalenceTracker(object):
self.replaced_by.setdefault(r, []) self.replaced_by.setdefault(r, [])
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
reason = str(reason)
self.event_list.append(_FunctionGraphEvent('change', node, self.event_list.append(_FunctionGraphEvent('change', node,
reason=str(reason), idx=i)) reason=reason, idx=i))
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, []) self.replaced_by.setdefault(new_r, [])
...@@ -2190,7 +2191,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2190,7 +2191,10 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph=None, # If present the optimized graph. we ignore it. fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None, output_keys=None,
name=None): name=None):
self.mode = mode
self.profile = profile self.profile = profile
if profile:
raise Exception("DebugMode do not support profiling.")
optimizer = mode.optimizer optimizer = mode.optimizer
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
# Variable (not in a list) # Variable (not in a list)
...@@ -2298,18 +2302,15 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2298,18 +2302,15 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# the 'no_borrow' outputs are the ones for which that we can't return # the 'no_borrow' outputs are the ones for which that we can't return
# the internal storage pointer. # the internal storage pointer.
no_borrow = [ no_borrow = [output for output, spec in
output izip(fgraph.outputs, outputs + additional_outputs)
for output, spec in izip(fgraph.outputs,
outputs + additional_outputs)
if not spec.borrow] if not spec.borrow]
if no_borrow: if no_borrow:
self.linker = linker.accept( self.linker = linker.accept(
fgraph, fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow))
no_recycling=infer_reuse_pattern(fgraph, no_borrow))
else: else:
self.linker = linker.accept(fgraph) self.linker = linker.accept(fgraph)
fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
self.expanded_inputs = inputs self.expanded_inputs = inputs
...@@ -2318,99 +2319,16 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2318,99 +2319,16 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.return_none = return_none self.return_none = return_none
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.mode = mode
self.on_unused_input = on_unused_input # Used for the pickling/copy self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys self.output_keys = output_keys
self.name = name self.name = name
def create(self, defaults=None, trustme=False, storage_map=None): self.required = [(i.value is None) for i in self.inputs]
""" self.refeed = [
Create a function. (i.value is not None and
not isinstance(i.value, gof.Container) and
Parameters i.update is None)
---------- for i in self.inputs]
defaults
A list matching the inputs list and providing default values if the
default for an input is None, then that input is a required input.
For an input with an update, the default acts as initialization.
trustme
Disables some exceptions, used internally.
"""
if defaults is None:
defaults = [None] * len(self.inputs)
# List of independent one-element lists, will be passed to the linker.
input_storage = []
_defaults = []
# The following loop is to fill in the input_storage and _defaults
# lists.
for (input, indices, subinputs), default in izip(self.indices,
defaults):
__default = default
if isinstance(default, gof.Container):
# If the default is a gof.Container, this means we want to
# share the same storage. This is done by appending
# default.storage to input_storage.
if indices is not None:
raise TypeError("Cannot take a Container instance as "
"default for a SymbolicInput.")
input_storage.append(default.storage)
default = None
else:
# Normal case: one new, independent storage unit
input_storage.append([None])
# Filling _defaults. Each entry is a tuple of three elements:
# (required, refeed, value)
# - required means that the user must provide a value when calling
# the function
# - refeed means that we want to put the default back in the
# storage after each function call
# - value is the value that will be put in the storage initially
if input.update is not None:
# If the input has an update, then (logically) it is
# not required since it is just a parameter and of
# course we don't want to refeed the default back into
# the storage as it would defeat the point of updating
# it. We always do this policy.
if default is None:
if trustme or isinstance(__default, gof.Container):
_defaults.append((False, False, None))
else:
# This might catch some bugs early
raise ValueError(
"A default (initial) value is required for an "
"input which can update itself.", input)
else:
_defaults.append((False, False, default))
else:
if default is None:
if trustme or isinstance(__default, gof.Container):
_defaults.append((False, False, None))
else:
# No default, so this is a required
# input. Nothing to feed back, initial value
# is None.
_defaults.append((True, False, None))
else:
# Default value. It is not required, but we want
# to put it back into the storage everytime so it
# behaves like most programming languages' default
# values
_defaults.append((False, True, default))
defaults = _defaults
# Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage,
storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self,
name=self.name)
return fn
######################## ########################
......
...@@ -1452,7 +1452,6 @@ class FunctionMaker(object): ...@@ -1452,7 +1452,6 @@ class FunctionMaker(object):
theano.gof.cc.get_module_cache().refresh() theano.gof.cc.get_module_cache().refresh()
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
# Variable (not in a list) # Variable (not in a list)
self.orig_outputs = outputs
unpack_single = False unpack_single = False
return_none = False return_none = False
if outputs is None: if outputs is None:
......
...@@ -645,7 +645,7 @@ class T_picklefunction(unittest.TestCase): ...@@ -645,7 +645,7 @@ class T_picklefunction(unittest.TestCase):
self.assertTrue(len(f.defaults) == len(g.defaults)) self.assertTrue(len(f.defaults) == len(g.defaults))
self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs) self.assertTrue(f._check_for_aliased_inputs is g._check_for_aliased_inputs)
self.assertTrue(f.name == g.name) self.assertTrue(f.name == g.name)
self.assertTrue(f.maker.fgraph.name == f.maker.fgraph.name) self.assertTrue(f.maker.fgraph.name == g.maker.fgraph.name)
# print 'f.defaults = %s' % (f.defaults, ) # print 'f.defaults = %s' % (f.defaults, )
# print 'g.defaults = %s' % (g.defaults, ) # print 'g.defaults = %s' % (g.defaults, )
self.assertTrue(all([f_req == g_req and f_feed == g_feed and self.assertTrue(all([f_req == g_req and f_feed == g_feed and
...@@ -681,9 +681,11 @@ class T_picklefunction(unittest.TestCase): ...@@ -681,9 +681,11 @@ class T_picklefunction(unittest.TestCase):
raise raise
self.assertTrue(f.trust_input is g.trust_input) self.assertTrue(f.trust_input is g.trust_input)
f(np.asarray(2.)) f(np.asarray(2.))
self.assertRaises((ValueError, AttributeError), f, 2.) self.assertRaises((ValueError, AttributeError,
theano.compile.debugmode.InvalidValueError), f, 2.)
g(np.asarray(2.)) g(np.asarray(2.))
self.assertRaises((ValueError, AttributeError), g, 2.) self.assertRaises((ValueError, AttributeError,
theano.compile.debugmode.InvalidValueError), g, 2.)
def test_output_keys(self): def test_output_keys(self):
x = T.vector() x = T.vector()
...@@ -1026,7 +1028,9 @@ def test_sync_update(): ...@@ -1026,7 +1028,9 @@ def test_sync_update():
f.sync_shared() f.sync_shared()
# Sync to make sure all computation are finished. # Sync to make sure all computation are finished.
t_2 = time.time() t_2 = time.time()
assert (t_1 - t_0) > (t_2 - t_1) d1 = (t_1 - t_0)
d2 = (t_2 - t_1)
assert d1 > d2, (d1, d2)
else: else:
raise SkipTest("Sync is only availble when pygpu is activated.") raise SkipTest("Sync is only availble when pygpu is activated.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论