提交 4c49de7f authored 作者: Frederic Bastien's avatar Frederic Bastien

Make test for cast to/from int/float16 and make sure to have a cast in the inner…

Make test for cast to/from int/float16 and make sure to have a cast in the inner graph instead of identify. (it was still working with identity, but having an explicit cast is better)
上级 0e595433
...@@ -78,22 +78,33 @@ class test_float16(): ...@@ -78,22 +78,33 @@ class test_float16():
def test_cast_float16(self): def test_cast_float16(self):
f16 = theano.tensor.vector(dtype='float16') f16 = theano.tensor.vector(dtype='float16')
f32 = theano.tensor.fvector() f32 = theano.tensor.fvector()
f = theano.function([f16, f32], i8 = theano.tensor.bvector()
f = theano.function([f16, f32, i8],
[f16.astype('float32'), [f16.astype('float32'),
f32.astype('float16'), f32.astype('float16'),
f32.astype('float64')], f32.astype('float64'),
f16.astype('int8'),
f32.astype('int8'),
i8.astype('float16'),
i8.astype('float32')],
mode=mode_with_gpu) mode=mode_with_gpu)
d1 = numpy.random.rand(4).astype('float16') d1 = (numpy.random.rand(4) * 10).astype('float16')
d2 = numpy.random.rand(5).astype('float32') d2 = (numpy.random.rand(5) * 10).astype('float32')
res = f(d1, d2) d3 = (numpy.random.rand(6) * 10).astype('int8')
res = f(d1, d2, d3)
assert res[0].dtype == "float32"
assert res[1].dtype == "float16" for i, out in enumerate(f.outputs):
assert res[2].dtype == "float64" dtype = out.variable.dtype
assert_allclose(d1, res[0]) assert res[i].dtype == dtype
assert_allclose(d2, res[1]) inp = out.variable.owner.inputs[0]
assert_allclose(d2, res[2]) if inp.dtype == 'float16':
d = d1
elif inp.dtype == 'float32':
d = d2
else:
d = d3
assert_allclose(d.astype(dtype), res[i])
class test_GpuDimShuffle(test_elemwise.test_DimShuffle): class test_GpuDimShuffle(test_elemwise.test_DimShuffle):
......
...@@ -2222,7 +2222,7 @@ class Cast(UnaryScalarOp): ...@@ -2222,7 +2222,7 @@ class Cast(UnaryScalarOp):
def clone_float32(self): def clone_float32(self):
if self.o_type == float16: if self.o_type == float16:
return identity return convert_to_float32
return self return self
def make_new_inplace(self, output_types_preference=None, name=None): def make_new_inplace(self, output_types_preference=None, name=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论