提交 69aaa256 authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #319 from goodfeli/ifelse_error

fixed error handling and doc for ifelse
......@@ -275,7 +275,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
If it evaluates to 0 it corresponds to False, anything else stands
for True.
:type then_branch: list of theano expressions/ theano expressions
:type then_branch: list of theano expressions/ theano expression
:param then_branch:
A single theano variable or a list of theano variables that the
function should return as the output if ``condition`` evaluates to
......@@ -300,11 +300,6 @@ def ifelse(condition, then_branch, else_branch, name=None):
``then_branch`` or in the ``else_branch`` depending on the value of
``cond``.
"""
if type(then_branch) is not type(else_branch):
raise ValueError(('The two branches should be identical. '
'This error could be raised if for example '
' you provided a one element list on the then '
' branch but a tensor on the else branch'))
rval_type = None
if type(then_branch) is list:
......@@ -317,6 +312,21 @@ def ifelse(condition, then_branch, else_branch, name=None):
if type(else_branch) not in (list, tuple):
else_branch = [else_branch]
for then_branch_elem, else_branch_elem in zip(then_branch, else_branch):
then_branch_elem = theano.tensor.as_tensor_variable(then_branch_elem)
else_branch_elem = theano.tensor.as_tensor_variable(else_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
raise ValueError(('The two branches should have identical types, '
' but they are '+str(then_branch_elem.type)+' and '+
str(else_branch_elem.type)+' respectively. '
'This error could be raised if for example '
' you provided a one element list on the then '
' branch but a tensor on the else branch'))
if len(then_branch) != len(else_branch):
raise ValueError(('The number of values on the `then` branch'
' should have the same number of variables as '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论