提交 15694cfb authored 作者: Frederic Bastien's avatar Frederic Bastien

add test for the change to MakeVector.

上级 e4d5718c
...@@ -230,8 +230,14 @@ class MakeVector(T.Op): ...@@ -230,8 +230,14 @@ class MakeVector(T.Op):
dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs]) dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs])
#upcast the input to the determined dtype, but don't upcast downcast anything #upcast the input to the determined dtype, but don't upcast downcast anything
assert dtype==self.dtype, "Upcast the input of MakeVector to dtype gived in init without precissino loss only." assert dtype==self.dtype, "Upcast the input of MakeVector to dtype gived in init without precissino loss only."
if not all(self.dtype == T.cast(i,dtype=dtype).dtype for a in inputs):
raise TypeError("MakeVector.make_node expected inputs upcastable to %s. got %s"%(
self.dtype,
str([i.dtype for i in inputs])
))
inputs=[T.cast(i,dtype=dtype) for i in inputs] inputs=[T.cast(i,dtype=dtype) for i in inputs]
assert all(a.type == inputs[0].type for a in inputs) assert all(self.dtype == a.dtype for a in inputs)
assert all(a.ndim==0 for a in inputs)
if inputs: if inputs:
dtype = inputs[0].type.dtype dtype = inputs[0].type.dtype
......
...@@ -1538,7 +1538,7 @@ class T_local_sum(unittest.TestCase): ...@@ -1538,7 +1538,7 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_all_to_none(self): def test_local_sum_all_to_none(self):
a = T.tensor3() a = T.tensor3()
input=numpy.arange(3*3*3).reshape(3,3,3) input=numpy.arange(3*3*3).reshape(3,3,3)
f = theano.function([a],a.sum()),mode=self.mode) f = theano.function([a],a.sum(),mode=self.mode)
assert len(f.maker.env.nodes)==1 assert len(f.maker.env.nodes)==1
assert numpy.allclose(f(input),input.sum()) assert numpy.allclose(f(input),input.sum())
...@@ -1632,6 +1632,33 @@ class T_local_sum_dimshuffle(unittest.TestCase): ...@@ -1632,6 +1632,33 @@ class T_local_sum_dimshuffle(unittest.TestCase):
# test_local_sum_prod_dimshuffle (a * b * c) # test_local_sum_prod_dimshuffle (a * b * c)
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d)) # test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
def test_make_vector_upcast():
b = T.bscalar()
i = T.iscalar()
d = T.dscalar()
opt.MakeVector(dtype="int8")(b,b)
opt.MakeVector(dtype="int32")(i,b)
opt.MakeVector(dtype="int32")(b,i)
opt.MakeVector(dtype="float64")(b,i)
opt.MakeVector(dtype="float64")(b,d)
opt.MakeVector(dtype="float64")(d,i)
#should fail
for (dtype,inputs) in [("int8",(b,i)),
("int8",(i,b)),
("int8",(b,d)),
("int8",(i,i)),
("int32",(d,i)),
("int32",(i,d)),
("float32",(i,d)),
]:
try:
opt.MakeVector(dtype=dtype)(*inputs)
raise Exception("Theano should have raised an error")
except AssertionError:
pass
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论