提交 f2a73ea1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Test for mismatched types in then/else branches.

上级 95470852
...@@ -2,6 +2,10 @@ import numpy ...@@ -2,6 +2,10 @@ import numpy
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano import theano
from theano import tensor
from theano.ifelse import ifelse
from theano.tensor import TensorType
from theano.sandbox.cuda.var import float32_shared_constructor as f32sc from theano.sandbox.cuda.var import float32_shared_constructor as f32sc
from theano.sandbox.cuda import CudaNdarrayType, cuda_available from theano.sandbox.cuda import CudaNdarrayType, cuda_available
...@@ -61,9 +65,32 @@ def test_updates2(): ...@@ -61,9 +65,32 @@ def test_updates2():
output_var = f32sc(name="output", output_var = f32sc(name="output",
value=numpy.zeros((10,10), 'float32')) value=numpy.zeros((10,10), 'float32'))
x = theano.tensor.fmatrix('x') x = tensor.fmatrix('x')
output_updates = {output_var:x**2} output_updates = {output_var:x**2}
output_givens = {x:data} output_givens = {x:data}
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(inputs=[], outputs=[],
updates=output_updates, givens=output_givens) updates=output_updates, givens=output_givens)
output_func() output_func()
def test_ifelse():
data = numpy.float32([1,2,3,4])
x = f32sc(data)
y = x + 1
cond = theano.tensor.iscalar('cond')
assert isinstance(x.type, CudaNdarrayType)
assert isinstance(y.type, TensorType)
out1 = ifelse(cond, x, x+1)
out2 = ifelse(cond, x+1, x)
assert isinstance(out1.type, TensorType)
assert isinstance(out2.type, TensorType)
f = theano.function([cond], out1)
g = theano.function([cond], out2)
assert numpy.all(f(0) == data+1)
assert numpy.all(f(1) == data)
assert numpy.all(g(0) == data)
assert numpy.all(g(1) == data+1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论