提交 9c262caa authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5442 from affanv14/bugfix

Fix bug caused when setting default value in Param in interactive mode
......@@ -749,6 +749,12 @@ class Function(object):
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
def restore_defaults():
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value
profile = self.profile
t0 = time.time()
......@@ -804,6 +810,7 @@ class Function(object):
e.args = ("Bad input " + argument_name + " to " +
function_name + " at index %d (0-based). %s"
% (i, where),) + e.args
restore_defaults()
raise
s.provided += 1
i += 1
......@@ -853,14 +860,17 @@ class Function(object):
if not self.trust_input:
for c in self.input_storage:
if c.required and not c.provided:
restore_defaults()
raise TypeError("Missing required input: %s" %
getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.provided > 1:
restore_defaults()
raise TypeError("Multiple values for input: %s" %
getattr(self.inv_finder[c], 'variable',
self.inv_finder[c]))
if c.implicit and c.provided > 0:
restore_defaults()
raise TypeError(
'Tried to provide value for implicit input: %s'
% getattr(self.inv_finder[c], 'variable',
......@@ -873,6 +883,7 @@ class Function(object):
self.fn() if output_subset is None else\
self.fn(output_subset=output_subset)
except Exception:
restore_defaults()
if hasattr(self.fn, 'position_of_error'):
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
......@@ -925,11 +936,7 @@ class Function(object):
outputs = outputs[:self.n_returned_outputs]
# Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, gof.Container):
value = value.storage[0]
self[i] = value
restore_defaults()
#
# NOTE: This logic needs to be replicated in
# scan.
......
......@@ -579,6 +579,20 @@ class T_function(unittest.TestCase):
if not isinstance(key, theano.gof.Constant):
assert (val[0] is None)
def test_default_values(self):
"""
Check that default values are restored
when an exception occurs in interactive mode.
"""
a, b = T.dscalars('a', 'b')
c = a + b
func = theano.function([theano.In(a, name='first'), theano.In(b, value=1, name='second')], c)
x = func(first=1)
try:
func(second=2)
except TypeError:
assert(func(first=1) == x)
class T_picklefunction(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论