提交 9d28f392 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed how I deal with return values from the lambda expression

This follows the discussion that I had with Pascal, that tries to simplify how scan deals with what the user gives it. In that direction I first depricated some of the possibilities the user had, in order to get better clarity in the code.
上级 349848f4
...@@ -155,135 +155,92 @@ def clone( output ...@@ -155,135 +155,92 @@ def clone( output
def get_updates_and_outputs(outputs_updates): def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates dictionary and the This function tries to recognize the updates dictionary, the
list of outputs from the input argument and return them in a list of outputs and the stopping condition returned by the
predefined order lambda expression and arrange them in a predefined order
The code that follows tries to be as flexible as possible allowing the
user to return the output and updates in any order, and giving the
updates however (s)he wants ( as a dictionary or a list o pairs ..)
Is there a way to compress all this by writing it in a more
pythonic/functional way?
""" """
outputs = [] def is_outputs(elem):
updates = {} if (isinstance(elem, (list,tuple)) and
cond = None all([isinstance(x, theano.Variable) for x in elem])):
return True
def pick_from2(elem0, elem1): if isinstance(elem, theano.Variable):
lupd = {} return True
lout = [] return False
if ( isinstance(elem0,dict) or
( isinstance(elem0, (list,tuple)) and def is_updates(elem):
isinstance(elem0[0], (list,tuple)))): if isinstance(elem, dict):
# elem0 is the updates dictionary / list return True
lupd = dict(elem0) # Dictionaries can be given as lists of tuples
lout = elem1 if (isinstance(elem, (list, tuple)) and
if not isinstance(outputs, (list,tuple)): all([isinstance(x, (list,tuple)) and len(x) ==2
lout = [outputs] for x in elem])):
elif ( isinstance(elem1, dict) or return True
( isinstance(elem1, (list,tuple)) and return False
isinstance(elem1[0], (list,tuple))) ):
# elem1 is the updates dictionary / list def is_condition(elem):
lupd = dict(elem1) return isinstance(elem, theano.scan_module.until)
lout = elem0
if not isinstance(outputs, (list,tuple)): def _list(x):
lout = [outputs] if isinstance(x, (list, tuple)):
else : return list(x)
if ( isinstance(outputs_updates, (list,tuple)) and
isinstance(outputs_updates[0], (list,tuple))):
lout = []
lupd = dict(outputs_updates)
else:
lout = outputs_updates
lupd = {}
return lupd, lout
def pick_from1(elem0):
lupd = {}
lout = []
if ( isinstance(elem0, dict) or
(isinstance(elem0, (list,tuple)) and
isinstance(elem0[0], (list, tuple)))):
lupd = dict(elem0)
else: else:
if not isinstance(elem0, (list, tuple)): return [x]
lout = [elem0]
if is_outputs(ls):
return None, _list(ls), {}
if is_updates(ls):
return None, [], dict(ls)
if not isinstance(ls, (list, tuple)):
raise ValueError(('Scan can not parse the return value'
' of your lambda expression'))
ls = list(ls)
deprication_msg = ('The return value of the lambda function'
' has been restricted. you have to always return first the'
' outputs (if any), afterwards the updates (if any) and'
' at the end the conclusion')
error_msg = 'Scan can not parse the return value of your lambda expression'
if len(ls) == 2:
if is_outputs(ls[0]):
if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1]))
elif is_condition(ls[1]):
return ( ls[1].condition, _list(ls[0]), {})
else: else:
lout = elem0 raise ValueError(error_msg)
return lupd, lout elif is_updates(ls[0]):
if is_outputs(ls[1]):
# we will try now to separate the outputs from the updates _logger.warning(deprication_msg)
if not isinstance(outputs_updates, (list,tuple)): return ( None, _list(ls[1]), dict(ls[0]) )
if isinstance(outputs_updates, dict) : elif is_condition(ls[1]):
# we have just an update dictionary return (ls[1].condition, [], dict(ls[0]))
updates = outputs_updates else:
elif isinstance(outputs_updates, until): raise ValueError(error_msg)
updates = outputs_updates.updates
outputs = outputs_updates.outputs
cond = outputs_updates.condition
else:
outputs = [outputs_updates]
elif len(outputs_updates) == 1:
rval = pick_from1(outputs_updates)
updates = rval[0]
outputs = rval[1]
elif len(outputs_updates) == 2:
elem0 = outputs_updates[0]
elem1 = outputs_updates[1]
if isinstance(elem0,until):
cond = elem0.condition
rval = pick_from1(elem1)
updates = rval[0].updates(elem0.updates)
outputs = rval[1] + elem0.outputs
elif isinstance(elem1, until):
cond = elem1.condition
rval = pick_from1(elem0)
updates = rval[0].update(elem1.updates)
outputs = rval[1] + elem1.outputs
else: else:
rval = pick_from2(elem0, elem1) raise ValueError(error_msg)
updates = rval[0] elif len(ls) == 3:
outputs = rval[1] if is_outputs(ls[0]):
elif len(outputs_updates) == 3: if is_updates(ls[1]):
elem0 = outputs_updates[0] if is_condition(ls[2]):
elem1 = outputs_updates[1] return (ls[2].condition, _list(ls[0]), dict(ls[1]))
elem2 = outputs_updates[2] else:
if isinstance(elem0, until): raise ValueError(error_msg)
cond = elem0.condition else:
rval = pick_from2(elem1, elem2) raise ValueError(error_msg)
updates = rval[0].update(elem0.updates) elif is_updates(ls[0]):
outputs = rval[1] + elem0.outputs if is_outputs(ls[1]):
elif isinstance(elem1, until): if is_condition(ls[2]):
cond = elem1.condition _logger.warning(deprication_msg)
rval = pick_from2(elem0, elem2) return (ls[2].condition, _list(ls[1]), dict(ls[0]))
updates = rval[0].update(elem1.updates) else:
outputs = rval[1] + elem1.outputs raise ValueError(error_msg)
elif isinstance(elem2, until): else:
cond = elem2.condition raise ValueError(error_msg)
rval = pick_from2(elem0, elem1)
updates = rval[0].update(elem2.updates)
outputs = rval[1] + elem2.outputs
else: else:
outputs = outputs_updates raise ValueError(error_msg)
else:
outputs = outputs_updates
# in case you return a tuple .. convert it to a list (there are certain
# operation that are not permited on tuples, like element assignment)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
else:
outputs = list(outputs)
# If you return numbers (highly unlikely) this will not go well for
# theano. We need to convert them to Theano constants:
for i,out in enumerate(outputs):
outputs[i] = tensor.as_tensor(out)
return cond, outputs, updates
def isNaN_or_Inf_or_None(x): def isNaN_or_Inf_or_None(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论