提交 9c3f95cc authored 作者: Ian Goodfellow's avatar Ian Goodfellow

rewrote test_casting to not use T.value

上级 caf42a91
......@@ -4,6 +4,8 @@ from theano import function
from theano.tensor.basic import (_convert_to_int32, _convert_to_int8, _convert_to_int16,
_convert_to_int64, _convert_to_float32, _convert_to_float64)
from theano.tensor import *
from theano import shared
value = shared
class test_casting(unittest.TestCase):
......@@ -39,43 +41,45 @@ class test_casting(unittest.TestCase):
self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2)))
def test_convert_to_complex(self):
a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128')+0.5j)
val64 = numpy.ones(3, dtype='complex64') + 0.5j
val128 = numpy.ones(3, dtype='complex128') + 0.5j
f = function([a],basic._convert_to_complex128(a))
vec64 = TensorType('complex64',(False,))()
vec128 = TensorType('complex128',(False,))()
f = function([vec64],basic._convert_to_complex128(vec64))
#we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data))
f = function([b],basic._convert_to_complex128(b))
assert b.type.values_eq_approx(b.data, f(b.data))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(a.data, f(a.data))
f = function([b],basic._convert_to_complex64(b))
assert b.type.values_eq_approx(a.data, f(b.data))
for nbits in (64, 128):
# upcasting to complex128
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic._convert_to_complex128(a))
assert a.type.values_eq_approx(b.data, f(a.data))
# upcasting to complex64
for t in ['int8','int16','int32','int64','float32']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
# downcast to complex64
for t in ['float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
assert vec64.type.values_eq_approx(val128, f(val64))
f = function([vec128],basic._convert_to_complex128(vec128))
assert vec64.type.values_eq_approx(val128, f(val128))
f = function([vec64],basic._convert_to_complex64(vec64))
assert vec64.type.values_eq_approx(val64, f(val64))
f = function([vec128],basic._convert_to_complex64(vec128))
assert vec128.type.values_eq_approx(val64, f(val128))
# upcasting to complex128
for t in ['int8','int16','int32','int64','float32','float64']:
a = shared(numpy.ones(3, dtype=t))
b = shared(numpy.ones(3, dtype='complex128'))
f = function([],basic._convert_to_complex128(a))
assert a.type.values_eq_approx(b.get_value(), f())
# upcasting to complex64
for t in ['int8','int16','int32','int64','float32']:
a = shared(numpy.ones(3, dtype=t))
b = shared(numpy.ones(3, dtype='complex64'))
f = function([],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.get_value(), f())
# downcast to complex64
for t in ['float64']:
a = shared(numpy.ones(3, dtype=t))
b = shared(numpy.ones(3, dtype='complex64'))
f = function([],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.get_value(), f())
def test_bug_complext_10_august_09(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论