提交 9c3f95cc authored 作者: Ian Goodfellow's avatar Ian Goodfellow

rewrote test_casting to not use T.value

上级 caf42a91
...@@ -4,6 +4,8 @@ from theano import function ...@@ -4,6 +4,8 @@ from theano import function
from theano.tensor.basic import (_convert_to_int32, _convert_to_int8, _convert_to_int16, from theano.tensor.basic import (_convert_to_int32, _convert_to_int8, _convert_to_int16,
_convert_to_int64, _convert_to_float32, _convert_to_float64) _convert_to_int64, _convert_to_float32, _convert_to_float64)
from theano.tensor import * from theano.tensor import *
from theano import shared
value = shared
class test_casting(unittest.TestCase): class test_casting(unittest.TestCase):
...@@ -39,43 +41,45 @@ class test_casting(unittest.TestCase): ...@@ -39,43 +41,45 @@ class test_casting(unittest.TestCase):
self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2))) self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2)))
def test_convert_to_complex(self): def test_convert_to_complex(self):
a = value(numpy.ones(3, dtype='complex64')+0.5j) val64 = numpy.ones(3, dtype='complex64') + 0.5j
b = value(numpy.ones(3, dtype='complex128')+0.5j) val128 = numpy.ones(3, dtype='complex128') + 0.5j
f = function([a],basic._convert_to_complex128(a)) vec64 = TensorType('complex64',(False,))()
vec128 = TensorType('complex128',(False,))()
f = function([vec64],basic._convert_to_complex128(vec64))
#we need to compare with the same type. #we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data)) assert vec64.type.values_eq_approx(val128, f(val64))
f = function([b],basic._convert_to_complex128(b)) f = function([vec128],basic._convert_to_complex128(vec128))
assert b.type.values_eq_approx(b.data, f(b.data)) assert vec64.type.values_eq_approx(val128, f(val128))
f = function([a],basic._convert_to_complex64(a)) f = function([vec64],basic._convert_to_complex64(vec64))
assert a.type.values_eq_approx(a.data, f(a.data)) assert vec64.type.values_eq_approx(val64, f(val64))
f = function([b],basic._convert_to_complex64(b)) f = function([vec128],basic._convert_to_complex64(vec128))
assert b.type.values_eq_approx(a.data, f(b.data)) assert vec128.type.values_eq_approx(val64, f(val128))
for nbits in (64, 128): # upcasting to complex128
# upcasting to complex128 for t in ['int8','int16','int32','int64','float32','float64']:
for t in ['int8','int16','int32','int64','float32','float64']: a = shared(numpy.ones(3, dtype=t))
a = value(numpy.ones(3, dtype=t)) b = shared(numpy.ones(3, dtype='complex128'))
b = value(numpy.ones(3, dtype='complex128')) f = function([],basic._convert_to_complex128(a))
f = function([a],basic._convert_to_complex128(a)) assert a.type.values_eq_approx(b.get_value(), f())
assert a.type.values_eq_approx(b.data, f(a.data))
# upcasting to complex64
# upcasting to complex64 for t in ['int8','int16','int32','int64','float32']:
for t in ['int8','int16','int32','int64','float32']: a = shared(numpy.ones(3, dtype=t))
a = value(numpy.ones(3, dtype=t)) b = shared(numpy.ones(3, dtype='complex64'))
b = value(numpy.ones(3, dtype='complex64')) f = function([],basic._convert_to_complex64(a))
f = function([a],basic._convert_to_complex64(a)) assert a.type.values_eq_approx(b.get_value(), f())
assert a.type.values_eq_approx(b.data, f(a.data))
# downcast to complex64
# downcast to complex64 for t in ['float64']:
for t in ['float64']: a = shared(numpy.ones(3, dtype=t))
a = value(numpy.ones(3, dtype=t)) b = shared(numpy.ones(3, dtype='complex64'))
b = value(numpy.ones(3, dtype='complex64')) f = function([],basic._convert_to_complex64(a))
f = function([a],basic._convert_to_complex64(a)) assert a.type.values_eq_approx(b.get_value(), f())
assert a.type.values_eq_approx(b.data, f(a.data))
def test_bug_complext_10_august_09(self): def test_bug_complext_10_august_09(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论