提交 628b9720 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

In ifelse, convert "then" and "else" to TensorType

If one of the "then" elements is a TensorType, and the corresponding element in the "else" branch is a compatible type, or vice versa, convert it into TensorType, so both have the same Type, and the output Type can be determined. This should fix a problem reported by Ian. [Commit amended to fix problem spotted by Fred]
上级 f2a73ea1
......@@ -26,6 +26,7 @@ import logging
from theano.gof import PureOp, Apply
import theano.tensor
from theano.tensor import TensorType
import gof
from compile import optdb
......@@ -312,7 +313,10 @@ def ifelse(condition, then_branch, else_branch, name=None):
if type(else_branch) not in (list, tuple):
else_branch = [else_branch]
# Some of the elements might be converted into another type,
# we will store them in these new_... lists.
new_then_branch = []
new_else_branch = []
for then_branch_elem, else_branch_elem in zip(then_branch, else_branch):
if not isinstance(then_branch_elem, theano.Variable):
then_branch_elem = theano.tensor.as_tensor_variable(then_branch_elem)
......@@ -320,14 +324,32 @@ def ifelse(condition, then_branch, else_branch, name=None):
else_branch_elem = theano.tensor.as_tensor_variable(else_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
raise ValueError(('The two branches should have identical types, '
# If one of them is a TensorType, and the other one can be
# converted into one, then we try to do that.
# This case happens when one of the elements has a GPU type,
# for instance a shared variable that was silently moved to GPU.
if (isinstance(then_branch_elem.type, TensorType)
and not isinstance(else_branch_elem.type, TensorType)):
else_branch_elem = then_branch_elem.type.filter_variable(
else_branch_elem)
elif (isinstance(else_branch_elem.type, TensorType)
and not isinstance(then_branch_elem.type, TensorType)):
then_branch_elem = else_branch_elem.type.filter_variable(
then_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
# If the types still don't match, there is a problem.
raise ValueError(
('The two branches should have identical types, '
' but they are '+str(then_branch_elem.type)+' and '+
str(else_branch_elem.type)+' respectively. '
'This error could be raised if for example '
' you provided a one element list on the then '
' branch but a tensor on the else branch'))
new_then_branch.append(then_branch_elem)
new_else_branch.append(else_branch_elem)
if len(then_branch) != len(else_branch):
raise ValueError(('The number of values on the `then` branch'
......@@ -341,7 +363,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
gpu=False,
name=name)
ins = [condition] + list(then_branch) + list(else_branch)
ins = [condition] + list(new_then_branch) + list(new_else_branch)
rval = new_ifelse.make_node(*ins).outputs
if rval_type is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论