提交 f7bf373e authored 作者: James Bergstra's avatar James Bergstra

Added tests to tensor/tests/test_casting, and moved some from test_basic

上级 1c8a3067
...@@ -594,19 +594,6 @@ class T_Shape(unittest.TestCase): ...@@ -594,19 +594,6 @@ class T_Shape(unittest.TestCase):
s = shape(numpy.ones((5, 3, 10))) s = shape(numpy.ones((5, 3, 10)))
self.failUnless((eval_outputs([s]) == [5, 3, 10]).all()) self.failUnless((eval_outputs([s]) == [5, 3, 10]).all())
class T_Cast(unittest.TestCase):
def test_basic(self):
for type1 in ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']:
x = TensorType(dtype = type1, broadcastable = (False, )).make_variable()
for type2, converter in zip(['int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
[convert_to_int8, convert_to_int16, convert_to_int32, convert_to_int64,
convert_to_float32, convert_to_float64]):
y = converter(x)
f = inplace_func([compile.In(x, strict = True)], y)
a = numpy.arange(10, dtype = type1)
b = f(a)
self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2)))
class T_max_and_argmax(unittest.TestCase): class T_max_and_argmax(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
...@@ -1920,43 +1907,6 @@ def test_sum_overflow(): ...@@ -1920,43 +1907,6 @@ def test_sum_overflow():
f = function([a], sum(a)) f = function([a], sum(a))
assert f([1]*300) == 300 assert f([1]*300) == 300
def test_convert_to_complex():
a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128')+0.5j)
f = function([a],basic.convert_to_complex128(a))
#we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data))
f = function([b],basic.convert_to_complex128(b))
assert b.type.values_eq_approx(b.data, f(b.data))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(a.data, f(a.data))
#down cast don,t work for now
#f = function([b],basic.convert_to_complex64(b))
#assert b.type.values_eq_approx(b.data, f(b.data))
for nbits in (64, 128):
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic.convert_to_complex128(a))
assert a.type.values_eq_approx(b.data, f(a.data))
for t in ['int8','int16','int32','int64','float32']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
#this work, but should we allow it? How well it is implemented?
for t in ['float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
def test_default(): def test_default():
x, y = dscalars('xy') x, y = dscalars('xy')
z = default(x, y) z = default(x, y)
...@@ -1974,16 +1924,6 @@ def test_default_state(): ...@@ -1974,16 +1924,6 @@ def test_default_state():
f['x'] = None f['x'] = None
assert f(1) == 4.8 assert f(1) == 4.8
assert f(2.2) == 7 assert f(2.2) == 7
def test_bug_complext_10_august_09():
v0 = dmatrix()
v1 = basic.convert_to_complex128(v0)
inputs = [v0]
outputs = [v1]
f = function(inputs, outputs)
i = numpy.zeros((2,2))
assert (f(i)==numpy.zeros((2,2))).all()
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT': if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
......
import unittest import unittest
from theano import function from theano import function
from theano.tensor.basic import (_convert_to_int32, _convert_to_int8, _convert_to_int16,
_convert_to_int64, _convert_to_float32, _convert_to_float64)
from theano.tensor import * from theano.tensor import *
class test_casting(unittest.TestCase): class test_casting(unittest.TestCase):
def test_0(self): def test_0(self):
for op_fn in convert_to_int32, convert_to_float32, convert_to_float64: for op_fn in _convert_to_int32, _convert_to_float32, _convert_to_float64:
for type_fn in bvector, ivector, fvector, dvector: for type_fn in bvector, ivector, fvector, dvector:
x = type_fn() x = type_fn()
f = function([x], op_fn(x)) f = function([x], op_fn(x))
...@@ -17,9 +20,70 @@ class test_casting(unittest.TestCase): ...@@ -17,9 +20,70 @@ class test_casting(unittest.TestCase):
def test_illegal(self): def test_illegal(self):
try: try:
x = zmatrix() x = zmatrix()
function([x], convert_to_float64(x))(numpy.ones((2,3), dtype='complex128')) function([x], cast(x, 'float64'))(numpy.ones((2,3), dtype='complex128'))
except TypeError: except TypeError:
return return
assert 0 assert 0
def test_basic(self):
for type1 in ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']:
x = TensorType(dtype = type1, broadcastable = (False, )).make_variable()
for type2, converter in zip(['int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
[_convert_to_int8, _convert_to_int16,
_convert_to_int32, _convert_to_int64,
_convert_to_float32, _convert_to_float64]):
y = converter(x)
f = function([compile.In(x, strict = True)], y)
a = numpy.arange(10, dtype = type1)
b = f(a)
self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2)))
def test_convert_to_complex(self):
a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128')+0.5j)
f = function([a],basic._convert_to_complex128(a))
#we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data))
f = function([b],basic._convert_to_complex128(b))
assert b.type.values_eq_approx(b.data, f(b.data))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(a.data, f(a.data))
f = function([b],basic._convert_to_complex64(b))
assert b.type.values_eq_approx(a.data, f(b.data))
for nbits in (64, 128):
# upcasting to complex128
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic._convert_to_complex128(a))
assert a.type.values_eq_approx(b.data, f(a.data))
# upcasting to complex64
for t in ['int8','int16','int32','int64','float32']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
# downcast to complex64
for t in ['float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
def test_bug_complext_10_august_09(self):
v0 = dmatrix()
v1 = basic._convert_to_complex128(v0)
inputs = [v0]
outputs = [v1]
f = function(inputs, outputs)
i = numpy.zeros((2,2))
assert (f(i)==numpy.zeros((2,2))).all()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论