提交 bcf6b24f authored 作者: James Bergstra's avatar James Bergstra

added a few simple tests of complex number support

上级 028935f7
...@@ -1126,6 +1126,10 @@ class T_add(unittest.TestCase): ...@@ -1126,6 +1126,10 @@ class T_add(unittest.TestCase):
def test_grad_col(self): def test_grad_col(self):
utt.verify_grad(add, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)]) utt.verify_grad(add, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
class T_ceil(unittest.TestCase):
def test_complex(self):
self.assertRaises(TypeError, ceil, zvector())
class T_exp(unittest.TestCase): class T_exp(unittest.TestCase):
def test_grad_0(self): def test_grad_0(self):
utt.verify_grad(exp, [ utt.verify_grad(exp, [
...@@ -1136,6 +1140,19 @@ class T_exp(unittest.TestCase): ...@@ -1136,6 +1140,19 @@ class T_exp(unittest.TestCase):
numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ], numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ],
[ 2.04832468, 0.50791564, -1.58892269]])]) [ 2.04832468, 0.50791564, -1.58892269]])])
def test_int(self):
x = ivector()
f = function([x], exp(x))
exp_3 = f([3])
assert exp_3.dtype == 'float64'
def test_complex(self):
x = zvector()
assert exp(x).dtype == 'complex128'
f = function([x], exp(x))
exp_3 = f([3+2j])
assert numpy.allclose(exp_3, numpy.exp(3+2j))
class T_divimpl(unittest.TestCase): class T_divimpl(unittest.TestCase):
def test_impls(self): def test_impls(self):
i = iscalar() i = iscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论