提交 9d28f392 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed how I deal with return values from the lambda expression

This follows the discussion that I had with Pascal, that tries to simplify how scan deals with what the user gives it. In that direction I first depricated some of the possibilities the user had, in order to get better clarity in the code.
上级 349848f4
......@@ -155,135 +155,92 @@ def clone( output
def get_updates_and_outputs(outputs_updates):
def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates dictionary and the
list of outputs from the input argument and return them in a
predefined order
This function tries to recognize the updates dictionary, the
list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order
The code that follows tries to be as flexible as possible allowing the
user to return the output and updates in any order, and giving the
updates however (s)he wants ( as a dictionary or a list o pairs ..)
Is there a way to compress all this by writing it in a more
pythonic/functional way?
"""
outputs = []
updates = {}
cond = None
def pick_from2(elem0, elem1):
lupd = {}
lout = []
if ( isinstance(elem0,dict) or
( isinstance(elem0, (list,tuple)) and
isinstance(elem0[0], (list,tuple)))):
# elem0 is the updates dictionary / list
lupd = dict(elem0)
lout = elem1
if not isinstance(outputs, (list,tuple)):
lout = [outputs]
elif ( isinstance(elem1, dict) or
( isinstance(elem1, (list,tuple)) and
isinstance(elem1[0], (list,tuple))) ):
# elem1 is the updates dictionary / list
lupd = dict(elem1)
lout = elem0
if not isinstance(outputs, (list,tuple)):
lout = [outputs]
else :
if ( isinstance(outputs_updates, (list,tuple)) and
isinstance(outputs_updates[0], (list,tuple))):
lout = []
lupd = dict(outputs_updates)
else:
lout = outputs_updates
lupd = {}
return lupd, lout
def pick_from1(elem0):
lupd = {}
lout = []
if ( isinstance(elem0, dict) or
(isinstance(elem0, (list,tuple)) and
isinstance(elem0[0], (list, tuple)))):
lupd = dict(elem0)
def is_outputs(elem):
if (isinstance(elem, (list,tuple)) and
all([isinstance(x, theano.Variable) for x in elem])):
return True
if isinstance(elem, theano.Variable):
return True
return False
def is_updates(elem):
if isinstance(elem, dict):
return True
# Dictionaries can be given as lists of tuples
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, (list,tuple)) and len(x) ==2
for x in elem])):
return True
return False
def is_condition(elem):
return isinstance(elem, theano.scan_module.until)
def _list(x):
if isinstance(x, (list, tuple)):
return list(x)
else:
if not isinstance(elem0, (list, tuple)):
lout = [elem0]
return [x]
if is_outputs(ls):
return None, _list(ls), {}
if is_updates(ls):
return None, [], dict(ls)
if not isinstance(ls, (list, tuple)):
raise ValueError(('Scan can not parse the return value'
' of your lambda expression'))
ls = list(ls)
deprication_msg = ('The return value of the lambda function'
' has been restricted. you have to always return first the'
' outputs (if any), afterwards the updates (if any) and'
' at the end the conclusion')
error_msg = 'Scan can not parse the return value of your lambda expression'
if len(ls) == 2:
if is_outputs(ls[0]):
if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1]))
elif is_condition(ls[1]):
return ( ls[1].condition, _list(ls[0]), {})
else:
lout = elem0
return lupd, lout
# we will try now to separate the outputs from the updates
if not isinstance(outputs_updates, (list,tuple)):
if isinstance(outputs_updates, dict) :
# we have just an update dictionary
updates = outputs_updates
elif isinstance(outputs_updates, until):
updates = outputs_updates.updates
outputs = outputs_updates.outputs
cond = outputs_updates.condition
else:
outputs = [outputs_updates]
elif len(outputs_updates) == 1:
rval = pick_from1(outputs_updates)
updates = rval[0]
outputs = rval[1]
elif len(outputs_updates) == 2:
elem0 = outputs_updates[0]
elem1 = outputs_updates[1]
if isinstance(elem0,until):
cond = elem0.condition
rval = pick_from1(elem1)
updates = rval[0].updates(elem0.updates)
outputs = rval[1] + elem0.outputs
elif isinstance(elem1, until):
cond = elem1.condition
rval = pick_from1(elem0)
updates = rval[0].update(elem1.updates)
outputs = rval[1] + elem1.outputs
raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
_logger.warning(deprication_msg)
return ( None, _list(ls[1]), dict(ls[0]) )
elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0]))
else:
raise ValueError(error_msg)
else:
rval = pick_from2(elem0, elem1)
updates = rval[0]
outputs = rval[1]
elif len(outputs_updates) == 3:
elem0 = outputs_updates[0]
elem1 = outputs_updates[1]
elem2 = outputs_updates[2]
if isinstance(elem0, until):
cond = elem0.condition
rval = pick_from2(elem1, elem2)
updates = rval[0].update(elem0.updates)
outputs = rval[1] + elem0.outputs
elif isinstance(elem1, until):
cond = elem1.condition
rval = pick_from2(elem0, elem2)
updates = rval[0].update(elem1.updates)
outputs = rval[1] + elem1.outputs
elif isinstance(elem2, until):
cond = elem2.condition
rval = pick_from2(elem0, elem1)
updates = rval[0].update(elem2.updates)
outputs = rval[1] + elem2.outputs
raise ValueError(error_msg)
elif len(ls) == 3:
if is_outputs(ls[0]):
if is_updates(ls[1]):
if is_condition(ls[2]):
return (ls[2].condition, _list(ls[0]), dict(ls[1]))
else:
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
if is_condition(ls[2]):
_logger.warning(deprication_msg)
return (ls[2].condition, _list(ls[1]), dict(ls[0]))
else:
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
else:
outputs = outputs_updates
else:
outputs = outputs_updates
# in case you return a tuple .. convert it to a list (there are certain
# operation that are not permited on tuples, like element assignment)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
else:
outputs = list(outputs)
# If you return numbers (highly unlikely) this will not go well for
# theano. We need to convert them to Theano constants:
for i,out in enumerate(outputs):
outputs[i] = tensor.as_tensor(out)
return cond, outputs, updates
raise ValueError(error_msg)
def isNaN_or_Inf_or_None(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论