提交 977b8b75 authored 作者: Frederic's avatar Frederic

pep8

上级 1f61b6c9
import unittest import unittest
from theano import function from theano import function
from theano.tensor.basic import (_convert_to_int32, _convert_to_int8, _convert_to_int16, from theano.tensor.basic import (_convert_to_int32, _convert_to_int8,
_convert_to_int64, _convert_to_float32, _convert_to_float64) _convert_to_int16, _convert_to_int64,
_convert_to_float32, _convert_to_float64)
from theano.tensor import * from theano.tensor import *
class test_casting(unittest.TestCase): class test_casting(unittest.TestCase):
def test_0(self): def test_0(self):
for op_fn in _convert_to_int32, _convert_to_float32, _convert_to_float64: for op_fn in [_convert_to_int32, _convert_to_float32,
_convert_to_float64]:
for type_fn in bvector, ivector, fvector, dvector: for type_fn in bvector, ivector, fvector, dvector:
x = type_fn() x = type_fn()
f = function([x], op_fn(x)) f = function([x], op_fn(x))
xval = theano._asarray(numpy.random.rand(10)*10, dtype=type_fn.dtype) xval = theano._asarray(numpy.random.rand(10) * 10,
dtype=type_fn.dtype)
yval = f(xval) yval = f(xval)
assert str(yval.dtype) == op_fn.scalar_op.output_types_preference.spec[0].dtype assert (str(yval.dtype) ==
op_fn.scalar_op.output_types_preference.spec[0].dtype)
def test_illegal(self): def test_illegal(self):
try: try:
x = zmatrix() x = zmatrix()
function([x], cast(x, 'float64'))(numpy.ones((2,3), dtype='complex128')) function([x], cast(x, 'float64'))(numpy.ones((2, 3),
dtype='complex128'))
except TypeError: except TypeError:
return return
assert 0 assert 0
def test_basic(self): def test_basic(self):
for type1 in ['uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64']: for type1 in ['uint8', 'uint16', 'uint32', 'uint64',
x = TensorType(dtype = type1, broadcastable = (False, )).make_variable() 'int8', 'int16', 'int32', 'int64', 'float32', 'float64']:
for type2, converter in zip(['int8', 'int16', 'int32', 'int64', 'float32', 'float64'], x = TensorType(dtype=type1,
broadcastable=(False, )).make_variable()
for type2, converter in zip(['int8', 'int16', 'int32', 'int64',
'float32', 'float64'],
[_convert_to_int8, _convert_to_int16, [_convert_to_int8, _convert_to_int16,
_convert_to_int32, _convert_to_int64, _convert_to_int32, _convert_to_int64,
_convert_to_float32, _convert_to_float64]): _convert_to_float32,
_convert_to_float64]):
y = converter(x) y = converter(x)
f = function([compile.In(x, strict = True)], y) f = function([compile.In(x, strict=True)], y)
a = numpy.arange(10, dtype = type1) a = numpy.arange(10, dtype=type1)
b = f(a) b = f(a)
self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2))) self.assertTrue(numpy.all(b == numpy.arange(10, dtype=type2)))
def test_convert_to_complex(self): def test_convert_to_complex(self):
val64 = numpy.ones(3, dtype='complex64') + 0.5j val64 = numpy.ones(3, dtype='complex64') + 0.5j
val128 = numpy.ones(3, dtype='complex128') + 0.5j val128 = numpy.ones(3, dtype='complex128') + 0.5j
vec64 = TensorType('complex64',(False,))() vec64 = TensorType('complex64', (False, ))()
vec128 = TensorType('complex128',(False,))() vec128 = TensorType('complex128', (False, ))()
f = function([vec64],basic._convert_to_complex128(vec64)) f = function([vec64], basic._convert_to_complex128(vec64))
#we need to compare with the same type. #we need to compare with the same type.
assert vec64.type.values_eq_approx(val128, f(val64)) assert vec64.type.values_eq_approx(val128, f(val64))
f = function([vec128],basic._convert_to_complex128(vec128)) f = function([vec128], basic._convert_to_complex128(vec128))
assert vec64.type.values_eq_approx(val128, f(val128)) assert vec64.type.values_eq_approx(val128, f(val128))
f = function([vec64],basic._convert_to_complex64(vec64)) f = function([vec64], basic._convert_to_complex64(vec64))
assert vec64.type.values_eq_approx(val64, f(val64)) assert vec64.type.values_eq_approx(val64, f(val64))
f = function([vec128],basic._convert_to_complex64(vec128)) f = function([vec128], basic._convert_to_complex64(vec128))
assert vec128.type.values_eq_approx(val64, f(val128)) assert vec128.type.values_eq_approx(val64, f(val128))
# upcasting to complex128 # upcasting to complex128
for t in ['int8','int16','int32','int64','float32','float64']: for t in ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']:
a = theano.shared(numpy.ones(3, dtype=t)) a = theano.shared(numpy.ones(3, dtype=t))
b = theano.shared(numpy.ones(3, dtype='complex128')) b = theano.shared(numpy.ones(3, dtype='complex128'))
f = function([],basic._convert_to_complex128(a)) f = function([], basic._convert_to_complex128(a))
assert a.type.values_eq_approx(b.get_value(), f()) assert a.type.values_eq_approx(b.get_value(), f())
# upcasting to complex64 # upcasting to complex64
for t in ['int8','int16','int32','int64','float32']: for t in ['int8', 'int16', 'int32', 'int64', 'float32']:
a = theano.shared(numpy.ones(3, dtype=t)) a = theano.shared(numpy.ones(3, dtype=t))
b = theano.shared(numpy.ones(3, dtype='complex64')) b = theano.shared(numpy.ones(3, dtype='complex64'))
f = function([],basic._convert_to_complex64(a)) f = function([], basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.get_value(), f()) assert a.type.values_eq_approx(b.get_value(), f())
# downcast to complex64 # downcast to complex64
for t in ['float64']: for t in ['float64']:
a = theano.shared(numpy.ones(3, dtype=t)) a = theano.shared(numpy.ones(3, dtype=t))
b = theano.shared(numpy.ones(3, dtype='complex64')) b = theano.shared(numpy.ones(3, dtype='complex64'))
f = function([],basic._convert_to_complex64(a)) f = function([], basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.get_value(), f()) assert a.type.values_eq_approx(b.get_value(), f())
def test_bug_complext_10_august_09(self): def test_bug_complext_10_august_09(self):
v0 = dmatrix() v0 = dmatrix()
v1 = basic._convert_to_complex128(v0) v1 = basic._convert_to_complex128(v0)
...@@ -87,5 +95,5 @@ class test_casting(unittest.TestCase): ...@@ -87,5 +95,5 @@ class test_casting(unittest.TestCase):
inputs = [v0] inputs = [v0]
outputs = [v1] outputs = [v1]
f = function(inputs, outputs) f = function(inputs, outputs)
i = numpy.zeros((2,2)) i = numpy.zeros((2, 2))
assert (f(i)==numpy.zeros((2,2))).all() assert (f(i) == numpy.zeros((2, 2))).all()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论