提交 628b9720 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

In ifelse, convert "then" and "else" to TensorType

If one of the "then" elements is a TensorType, and the corresponding element in the "else" branch is a compatible type, or vice versa, convert it into TensorType, so both have the same Type, and the output Type can be determined. This should fix a problem reported by Ian. [Commit amended to fix problem spotted by Fred]
上级 f2a73ea1
...@@ -26,6 +26,7 @@ import logging ...@@ -26,6 +26,7 @@ import logging
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
import theano.tensor import theano.tensor
from theano.tensor import TensorType
import gof import gof
from compile import optdb from compile import optdb
...@@ -312,7 +313,10 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -312,7 +313,10 @@ def ifelse(condition, then_branch, else_branch, name=None):
if type(else_branch) not in (list, tuple): if type(else_branch) not in (list, tuple):
else_branch = [else_branch] else_branch = [else_branch]
# Some of the elements might be converted into another type,
# we will store them in these new_... lists.
new_then_branch = []
new_else_branch = []
for then_branch_elem, else_branch_elem in zip(then_branch, else_branch): for then_branch_elem, else_branch_elem in zip(then_branch, else_branch):
if not isinstance(then_branch_elem, theano.Variable): if not isinstance(then_branch_elem, theano.Variable):
then_branch_elem = theano.tensor.as_tensor_variable(then_branch_elem) then_branch_elem = theano.tensor.as_tensor_variable(then_branch_elem)
...@@ -320,14 +324,32 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -320,14 +324,32 @@ def ifelse(condition, then_branch, else_branch, name=None):
else_branch_elem = theano.tensor.as_tensor_variable(else_branch_elem) else_branch_elem = theano.tensor.as_tensor_variable(else_branch_elem)
if then_branch_elem.type != else_branch_elem.type: if then_branch_elem.type != else_branch_elem.type:
raise ValueError(('The two branches should have identical types, ' # If one of them is a TensorType, and the other one can be
# converted into one, then we try to do that.
# This case happens when one of the elements has a GPU type,
# for instance a shared variable that was silently moved to GPU.
if (isinstance(then_branch_elem.type, TensorType)
and not isinstance(else_branch_elem.type, TensorType)):
else_branch_elem = then_branch_elem.type.filter_variable(
else_branch_elem)
elif (isinstance(else_branch_elem.type, TensorType)
and not isinstance(then_branch_elem.type, TensorType)):
then_branch_elem = else_branch_elem.type.filter_variable(
then_branch_elem)
if then_branch_elem.type != else_branch_elem.type:
# If the types still don't match, there is a problem.
raise ValueError(
('The two branches should have identical types, '
' but they are '+str(then_branch_elem.type)+' and '+ ' but they are '+str(then_branch_elem.type)+' and '+
str(else_branch_elem.type)+' respectively. ' str(else_branch_elem.type)+' respectively. '
'This error could be raised if for example ' 'This error could be raised if for example '
' you provided a one element list on the then ' ' you provided a one element list on the then '
' branch but a tensor on the else branch')) ' branch but a tensor on the else branch'))
new_then_branch.append(then_branch_elem)
new_else_branch.append(else_branch_elem)
if len(then_branch) != len(else_branch): if len(then_branch) != len(else_branch):
raise ValueError(('The number of values on the `then` branch' raise ValueError(('The number of values on the `then` branch'
...@@ -341,7 +363,7 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -341,7 +363,7 @@ def ifelse(condition, then_branch, else_branch, name=None):
gpu=False, gpu=False,
name=name) name=name)
ins = [condition] + list(then_branch) + list(else_branch) ins = [condition] + list(new_then_branch) + list(new_else_branch)
rval = new_ifelse.make_node(*ins).outputs rval = new_ifelse.make_node(*ins).outputs
if rval_type is None: if rval_type is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论