提交 00587f67 authored 作者: James Bergstra's avatar James Bergstra

merge

差异被折叠。
...@@ -3,6 +3,8 @@ import theano ...@@ -3,6 +3,8 @@ import theano
from theano.tensor import * from theano.tensor import *
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from numpy.testing import dec
class TestRealImag(unittest.TestCase): class TestRealImag(unittest.TestCase):
def test0(self): def test0(self):
...@@ -53,6 +55,54 @@ class TestRealImag(unittest.TestCase): ...@@ -53,6 +55,54 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
@dec.knownfailureif(True,"Complex grads not enabled")
def test_mul_mixed0(self):
def f(a):
ac = complex(a[0], a[1])
return abs((ac)**2).sum()
rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5))
try:
utt.verify_grad(f, [aval])
except utt.verify_grad.E_grad, e:
print e.num_grad.gf
print e.analytic_grad
raise
@dec.knownfailureif(True,"Complex grads not enabled")
def test_mul_mixed1(self):
def f(a):
ac = complex(a[0], a[1])
return abs(ac).sum()
rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5))
try:
utt.verify_grad(f, [aval])
except utt.verify_grad.E_grad, e:
print e.num_grad.gf
print e.analytic_grad
raise
@dec.knownfailureif(True,"Complex grads not enabled")
def test_mul_mixed(self):
def f(a,b):
ac = complex(a[0], a[1])
return abs((ac*b)**2).sum()
rng = numpy.random.RandomState(9333)
aval = numpy.asarray(rng.randn(2,5))
bval = rng.randn(5)
try:
utt.verify_grad(f, [aval, bval])
except utt.verify_grad.E_grad, e:
print e.num_grad.gf
print e.analytic_grad
raise
def test_polar_grads(self): def test_polar_grads(self):
def f(m): def f(m):
c = complex_from_polar(abs(m[0]), m[1]) c = complex_from_polar(abs(m[0]), m[1])
...@@ -62,7 +112,6 @@ class TestRealImag(unittest.TestCase): ...@@ -62,7 +112,6 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
def test_abs_grad(self): def test_abs_grad(self):
def f(m): def f(m):
c = complex(m[0], m[1]) c = complex(m[0], m[1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论