提交 d9805320 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed error handling and doc for ifelse

(there was a problem that would reject a pair of branches where one was shared and the other was not, regardless of whether they had the same type)
上级 9517a606
...@@ -275,7 +275,7 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -275,7 +275,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
If it evaluates to 0 it corresponds to False, anything else stands If it evaluates to 0 it corresponds to False, anything else stands
for True. for True.
:type then_branch: list of theano expressions/ theano expressions :type then_branch: list of theano expressions/ theano expression
:param then_branch: :param then_branch:
A single theano variable or a list of theano variables that the A single theano variable or a list of theano variables that the
function should return as the output if ``condition`` evaluates to function should return as the output if ``condition`` evaluates to
...@@ -300,11 +300,6 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -300,11 +300,6 @@ def ifelse(condition, then_branch, else_branch, name=None):
``then_branch`` or in the ``else_branch`` depending on the value of ``then_branch`` or in the ``else_branch`` depending on the value of
``cond``. ``cond``.
""" """
if type(then_branch) is not type(else_branch):
raise ValueError(('The two branches should be identical. '
'This error could be raised if for example '
' you provided a one element list on the then '
' branch but a tensor on the else branch'))
rval_type = None rval_type = None
if type(then_branch) is list: if type(then_branch) is list:
...@@ -317,6 +312,21 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -317,6 +312,21 @@ def ifelse(condition, then_branch, else_branch, name=None):
if type(else_branch) not in (list, tuple): if type(else_branch) not in (list, tuple):
else_branch = [else_branch] else_branch = [else_branch]
for then_branch_elem, else_branch_elem in zip(then_branch, else_branch):
then_branch_elem = theano.tensor.as_tensor_variable(then_branch_elem)
else_branch_elem = theano.tensor.as_tensor_variable(else_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
raise ValueError(('The two branches should have identical types, '
' but they are '+str(then_branch_elem.type)+' and '+
str(else_branch_elem.type)+' respectively. '
'This error could be raised if for example '
' you provided a one element list on the then '
' branch but a tensor on the else branch'))
if len(then_branch) != len(else_branch): if len(then_branch) != len(else_branch):
raise ValueError(('The number of values on the `then` branch' raise ValueError(('The number of values on the `then` branch'
' should have the same number of variables as ' ' should have the same number of variables as '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论