提交 9d28f392 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed how I deal with return values from the lambda expression

This follows the discussion that I had with Pascal, that tries to simplify how scan deals with what the user gives it. In that direction I first depricated some of the possibilities the user had, in order to get better clarity in the code.
上级 349848f4
...@@ -155,135 +155,92 @@ def clone( output ...@@ -155,135 +155,92 @@ def clone( output
def get_updates_and_outputs(outputs_updates): def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates dictionary and the This function tries to recognize the updates dictionary, the
list of outputs from the input argument and return them in a list of outputs and the stopping condition returned by the
predefined order lambda expression and arrange them in a predefined order
The code that follows tries to be as flexible as possible allowing the
user to return the output and updates in any order, and giving the
updates however (s)he wants ( as a dictionary or a list o pairs ..)
Is there a way to compress all this by writing it in a more
pythonic/functional way?
""" """
outputs = [] def is_outputs(elem):
updates = {} if (isinstance(elem, (list,tuple)) and
cond = None all([isinstance(x, theano.Variable) for x in elem])):
return True
def pick_from2(elem0, elem1): if isinstance(elem, theano.Variable):
lupd = {} return True
lout = [] return False
if ( isinstance(elem0,dict) or
( isinstance(elem0, (list,tuple)) and def is_updates(elem):
isinstance(elem0[0], (list,tuple)))): if isinstance(elem, dict):
# elem0 is the updates dictionary / list return True
lupd = dict(elem0) # Dictionaries can be given as lists of tuples
lout = elem1 if (isinstance(elem, (list, tuple)) and
if not isinstance(outputs, (list,tuple)): all([isinstance(x, (list,tuple)) and len(x) ==2
lout = [outputs] for x in elem])):
elif ( isinstance(elem1, dict) or return True
( isinstance(elem1, (list,tuple)) and return False
isinstance(elem1[0], (list,tuple))) ):
# elem1 is the updates dictionary / list def is_condition(elem):
lupd = dict(elem1) return isinstance(elem, theano.scan_module.until)
lout = elem0
if not isinstance(outputs, (list,tuple)): def _list(x):
lout = [outputs] if isinstance(x, (list, tuple)):
else : return list(x)
if ( isinstance(outputs_updates, (list,tuple)) and
isinstance(outputs_updates[0], (list,tuple))):
lout = []
lupd = dict(outputs_updates)
else: else:
lout = outputs_updates return [x]
lupd = {}
return lupd, lout if is_outputs(ls):
return None, _list(ls), {}
def pick_from1(elem0): if is_updates(ls):
lupd = {} return None, [], dict(ls)
lout = [] if not isinstance(ls, (list, tuple)):
if ( isinstance(elem0, dict) or raise ValueError(('Scan can not parse the return value'
(isinstance(elem0, (list,tuple)) and ' of your lambda expression'))
isinstance(elem0[0], (list, tuple)))): ls = list(ls)
lupd = dict(elem0) deprication_msg = ('The return value of the lambda function'
' has been restricted. you have to always return first the'
' outputs (if any), afterwards the updates (if any) and'
' at the end the conclusion')
error_msg = 'Scan can not parse the return value of your lambda expression'
if len(ls) == 2:
if is_outputs(ls[0]):
if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1]))
elif is_condition(ls[1]):
return ( ls[1].condition, _list(ls[0]), {})
else: else:
if not isinstance(elem0, (list, tuple)): raise ValueError(error_msg)
lout = [elem0] elif is_updates(ls[0]):
if is_outputs(ls[1]):
_logger.warning(deprication_msg)
return ( None, _list(ls[1]), dict(ls[0]) )
elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0]))
else: else:
lout = elem0 raise ValueError(error_msg)
return lupd, lout
# we will try now to separate the outputs from the updates
if not isinstance(outputs_updates, (list,tuple)):
if isinstance(outputs_updates, dict) :
# we have just an update dictionary
updates = outputs_updates
elif isinstance(outputs_updates, until):
updates = outputs_updates.updates
outputs = outputs_updates.outputs
cond = outputs_updates.condition
else: else:
outputs = [outputs_updates] raise ValueError(error_msg)
elif len(outputs_updates) == 1: elif len(ls) == 3:
rval = pick_from1(outputs_updates) if is_outputs(ls[0]):
updates = rval[0] if is_updates(ls[1]):
outputs = rval[1] if is_condition(ls[2]):
elif len(outputs_updates) == 2: return (ls[2].condition, _list(ls[0]), dict(ls[1]))
elem0 = outputs_updates[0]
elem1 = outputs_updates[1]
if isinstance(elem0,until):
cond = elem0.condition
rval = pick_from1(elem1)
updates = rval[0].updates(elem0.updates)
outputs = rval[1] + elem0.outputs
elif isinstance(elem1, until):
cond = elem1.condition
rval = pick_from1(elem0)
updates = rval[0].update(elem1.updates)
outputs = rval[1] + elem1.outputs
else: else:
rval = pick_from2(elem0, elem1) raise ValueError(error_msg)
updates = rval[0]
outputs = rval[1]
elif len(outputs_updates) == 3:
elem0 = outputs_updates[0]
elem1 = outputs_updates[1]
elem2 = outputs_updates[2]
if isinstance(elem0, until):
cond = elem0.condition
rval = pick_from2(elem1, elem2)
updates = rval[0].update(elem0.updates)
outputs = rval[1] + elem0.outputs
elif isinstance(elem1, until):
cond = elem1.condition
rval = pick_from2(elem0, elem2)
updates = rval[0].update(elem1.updates)
outputs = rval[1] + elem1.outputs
elif isinstance(elem2, until):
cond = elem2.condition
rval = pick_from2(elem0, elem1)
updates = rval[0].update(elem2.updates)
outputs = rval[1] + elem2.outputs
else: else:
outputs = outputs_updates raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
if is_condition(ls[2]):
_logger.warning(deprication_msg)
return (ls[2].condition, _list(ls[1]), dict(ls[0]))
else: else:
outputs = outputs_updates raise ValueError(error_msg)
# in case you return a tuple .. convert it to a list (there are certain
# operation that are not permited on tuples, like element assignment)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
else: else:
outputs = list(outputs) raise ValueError(error_msg)
else:
# If you return numbers (highly unlikely) this will not go well for raise ValueError(error_msg)
# theano. We need to convert them to Theano constants:
for i,out in enumerate(outputs):
outputs[i] = tensor.as_tensor(out)
return cond, outputs, updates
def isNaN_or_Inf_or_None(x): def isNaN_or_Inf_or_None(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论