提交 1a1d1a02 authored 作者: Frederic's avatar Frederic

pep8

上级 4c90eecc
...@@ -5,30 +5,31 @@ from theano.tests import unittest_tools as utt ...@@ -5,30 +5,31 @@ from theano.tests import unittest_tools as utt
from numpy.testing import dec from numpy.testing import dec
class TestRealImag(unittest.TestCase): class TestRealImag(unittest.TestCase):
def test0(self): def test0(self):
x= zvector() x = zvector()
rng = numpy.random.RandomState(23) rng = numpy.random.RandomState(23)
xval = numpy.asarray(list(numpy.complex(rng.randn(), rng.randn()) for i in xrange(10))) xval = numpy.asarray(list(numpy.complex(rng.randn(), rng.randn())
assert numpy.all( xval.real == theano.function([x], real(x))(xval)) for i in xrange(10)))
assert numpy.all( xval.imag == theano.function([x], imag(x))(xval)) assert numpy.all(xval.real == theano.function([x], real(x))(xval))
assert numpy.all(xval.imag == theano.function([x], imag(x))(xval))
def test_on_real_input(self): def test_on_real_input(self):
x= dvector() x = dvector()
rng = numpy.random.RandomState(23) rng = numpy.random.RandomState(23)
xval = rng.randn(10) xval = rng.randn(10)
numpy.all( 0 == theano.function([x], imag(x))(xval)) numpy.all(0 == theano.function([x], imag(x))(xval))
numpy.all( xval == theano.function([x], real(x))(xval)) numpy.all(xval == theano.function([x], real(x))(xval))
x= imatrix() x = imatrix()
xval = numpy.asarray(rng.randn(3,3)*100, dtype='int32') xval = numpy.asarray(rng.randn(3, 3) * 100, dtype='int32')
numpy.all( 0 == theano.function([x], imag(x))(xval)) numpy.all(0 == theano.function([x], imag(x))(xval))
numpy.all( xval == theano.function([x], real(x))(xval)) numpy.all(xval == theano.function([x], real(x))(xval))
def test_cast(self): def test_cast(self):
x= zvector() x = zvector()
self.assertRaises(TypeError, cast, x, 'int32') self.assertRaises(TypeError, cast, x, 'int32')
def test_complex(self): def test_complex(self):
...@@ -36,27 +37,27 @@ class TestRealImag(unittest.TestCase): ...@@ -36,27 +37,27 @@ class TestRealImag(unittest.TestCase):
m = fmatrix() m = fmatrix()
c = complex(m[0], m[1]) c = complex(m[0], m[1])
assert c.type == cvector assert c.type == cvector
r,i = [real(c), imag(c)] r, i = [real(c), imag(c)]
assert r.type == fvector assert r.type == fvector
assert i.type == fvector assert i.type == fvector
f = theano.function([m], [r,i] ) f = theano.function([m], [r, i])
mval = numpy.asarray(rng.randn(2,5), dtype='float32') mval = numpy.asarray(rng.randn(2, 5), dtype='float32')
rval, ival = f(mval) rval, ival = f(mval)
assert numpy.all(rval == mval[0]), (rval,mval[0]) assert numpy.all(rval == mval[0]), (rval, mval[0])
assert numpy.all(ival == mval[1]), (ival, mval[1]) assert numpy.all(ival == mval[1]), (ival, mval[1])
@dec.knownfailureif(True,"Complex grads not enabled, see #178") @dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_complex_grads(self): def test_complex_grads(self):
def f(m): def f(m):
c = complex(m[0], m[1]) c = complex(m[0], m[1])
return .5 * real(c) + .9 * imag(c) return .5 * real(c) + .9 * imag(c)
rng = numpy.random.RandomState(9333) rng = numpy.random.RandomState(9333)
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2, 5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
@dec.knownfailureif(True,"Complex grads not enabled, see #178") @dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_mul_mixed0(self): def test_mul_mixed0(self):
def f(a): def f(a):
...@@ -64,7 +65,7 @@ class TestRealImag(unittest.TestCase): ...@@ -64,7 +65,7 @@ class TestRealImag(unittest.TestCase):
return abs((ac)**2).sum() return abs((ac)**2).sum()
rng = numpy.random.RandomState(9333) rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5)) aval = numpy.asarray(rng.randn(2, 5))
try: try:
utt.verify_grad(f, [aval]) utt.verify_grad(f, [aval])
except utt.verify_grad.E_grad, e: except utt.verify_grad.E_grad, e:
...@@ -72,7 +73,7 @@ class TestRealImag(unittest.TestCase): ...@@ -72,7 +73,7 @@ class TestRealImag(unittest.TestCase):
print e.analytic_grad print e.analytic_grad
raise raise
@dec.knownfailureif(True,"Complex grads not enabled, see #178") @dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_mul_mixed1(self): def test_mul_mixed1(self):
def f(a): def f(a):
...@@ -80,22 +81,23 @@ class TestRealImag(unittest.TestCase): ...@@ -80,22 +81,23 @@ class TestRealImag(unittest.TestCase):
return abs(ac).sum() return abs(ac).sum()
rng = numpy.random.RandomState(9333) rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5)) aval = numpy.asarray(rng.randn(2, 5))
try: try:
utt.verify_grad(f, [aval]) utt.verify_grad(f, [aval])
except utt.verify_grad.E_grad, e: except utt.verify_grad.E_grad, e:
print e.num_grad.gf print e.num_grad.gf
print e.analytic_grad print e.analytic_grad
raise raise
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
@dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_mul_mixed(self): def test_mul_mixed(self):
def f(a,b): def f(a, b):
ac = complex(a[0], a[1]) ac = complex(a[0], a[1])
return abs((ac*b)**2).sum() return abs((ac*b)**2).sum()
rng = numpy.random.RandomState(9333) rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5)) aval = numpy.asarray(rng.randn(2, 5))
bval = rng.randn(5) bval = rng.randn(5)
try: try:
utt.verify_grad(f, [aval, bval]) utt.verify_grad(f, [aval, bval])
...@@ -104,22 +106,22 @@ class TestRealImag(unittest.TestCase): ...@@ -104,22 +106,22 @@ class TestRealImag(unittest.TestCase):
print e.analytic_grad print e.analytic_grad
raise raise
@dec.knownfailureif(True,"Complex grads not enabled, see #178") @dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_polar_grads(self): def test_polar_grads(self):
def f(m): def f(m):
c = complex_from_polar(abs(m[0]), m[1]) c = complex_from_polar(abs(m[0]), m[1])
return .5 * real(c) + .9 * imag(c) return .5 * real(c) + .9 * imag(c)
rng = numpy.random.RandomState(9333) rng = numpy.random.RandomState(9333)
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2, 5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
@dec.knownfailureif(True,"Complex grads not enabled, see #178") @dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_abs_grad(self): def test_abs_grad(self):
def f(m): def f(m):
c = complex(m[0], m[1]) c = complex(m[0], m[1])
return .5 * abs(c) return .5 * abs(c)
rng = numpy.random.RandomState(9333) rng = numpy.random.RandomState(9333)
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2, 5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论