提交 9c262caa authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5442 from affanv14/bugfix

Fix bug caused when setting default value in Param in interactive mode
...@@ -749,6 +749,12 @@ class Function(object): ...@@ -749,6 +749,12 @@ class Function(object):
List of outputs on indices/keys from ``output_subset`` or all of them, List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed. if ``output_subset`` is not passed.
""" """
def restore_defaults():
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value
profile = self.profile profile = self.profile
t0 = time.time() t0 = time.time()
...@@ -804,6 +810,7 @@ class Function(object): ...@@ -804,6 +810,7 @@ class Function(object):
e.args = ("Bad input " + argument_name + " to " + e.args = ("Bad input " + argument_name + " to " +
function_name + " at index %d (0-based). %s" function_name + " at index %d (0-based). %s"
% (i, where),) + e.args % (i, where),) + e.args
restore_defaults()
raise raise
s.provided += 1 s.provided += 1
i += 1 i += 1
...@@ -853,14 +860,17 @@ class Function(object): ...@@ -853,14 +860,17 @@ class Function(object):
if not self.trust_input: if not self.trust_input:
for c in self.input_storage: for c in self.input_storage:
if c.required and not c.provided: if c.required and not c.provided:
restore_defaults()
raise TypeError("Missing required input: %s" % raise TypeError("Missing required input: %s" %
getattr(self.inv_finder[c], 'variable', getattr(self.inv_finder[c], 'variable',
self.inv_finder[c])) self.inv_finder[c]))
if c.provided > 1: if c.provided > 1:
restore_defaults()
raise TypeError("Multiple values for input: %s" % raise TypeError("Multiple values for input: %s" %
getattr(self.inv_finder[c], 'variable', getattr(self.inv_finder[c], 'variable',
self.inv_finder[c])) self.inv_finder[c]))
if c.implicit and c.provided > 0: if c.implicit and c.provided > 0:
restore_defaults()
raise TypeError( raise TypeError(
'Tried to provide value for implicit input: %s' 'Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable', % getattr(self.inv_finder[c], 'variable',
...@@ -873,6 +883,7 @@ class Function(object): ...@@ -873,6 +883,7 @@ class Function(object):
self.fn() if output_subset is None else\ self.fn() if output_subset is None else\
self.fn(output_subset=output_subset) self.fn(output_subset=output_subset)
except Exception: except Exception:
restore_defaults()
if hasattr(self.fn, 'position_of_error'): if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function or c linker # this is a new vm-provided function or c linker
# they need this because the exception manipulation # they need this because the exception manipulation
...@@ -925,11 +936,7 @@ class Function(object): ...@@ -925,11 +936,7 @@ class Function(object):
outputs = outputs[:self.n_returned_outputs] outputs = outputs[:self.n_returned_outputs]
# Put default values back in the storage # Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults): restore_defaults()
if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value
# #
# NOTE: This logic needs to be replicated in # NOTE: This logic needs to be replicated in
# scan. # scan.
......
...@@ -579,6 +579,20 @@ class T_function(unittest.TestCase): ...@@ -579,6 +579,20 @@ class T_function(unittest.TestCase):
if not isinstance(key, theano.gof.Constant): if not isinstance(key, theano.gof.Constant):
assert (val[0] is None) assert (val[0] is None)
def test_default_values(self):
"""
Check that default values are restored
when an exception occurs in interactive mode.
"""
a, b = T.dscalars('a', 'b')
c = a + b
func = theano.function([theano.In(a, name='first'), theano.In(b, value=1, name='second')], c)
x = func(first=1)
try:
func(second=2)
except TypeError:
assert(func(first=1) == x)
class T_picklefunction(unittest.TestCase): class T_picklefunction(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论