提交 a6ea5b1e authored 作者: Frederic Bastien's avatar Frederic Bastien

more fix to complex type cast and better test.

上级 af43180b
...@@ -175,11 +175,6 @@ class Scalar(Type): ...@@ -175,11 +175,6 @@ class Scalar(Type):
""" """
# todo: use C templating # todo: use C templating
return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast=""" return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast="""
complex_type& operator =(npy_float32 y) {
this->real=y;
this->imag=0;
return *this;
}
complex_type& operator =(theano_complex64 y) { complex_type& operator =(theano_complex64 y) {
this->real=y.real; this->real=y.real;
this->imag=y.imag; this->imag=y.imag;
......
...@@ -1921,8 +1921,8 @@ def test_sum_overflow(): ...@@ -1921,8 +1921,8 @@ def test_sum_overflow():
assert f([1]*300) == 300 assert f([1]*300) == 300
def test_convert_to_complex(): def test_convert_to_complex():
a = value(numpy.ones(3, dtype='complex64'))#+0.5j) a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128'))#+0.5j) b = value(numpy.ones(3, dtype='complex128')+0.5j)
f = function([a],basic.convert_to_complex128(a)) f = function([a],basic.convert_to_complex128(a))
#we need to compare with the same type. #we need to compare with the same type.
...@@ -1935,22 +1935,15 @@ def test_convert_to_complex(): ...@@ -1935,22 +1935,15 @@ def test_convert_to_complex():
assert a.type.values_eq_approx(a.data, f(a.data)) assert a.type.values_eq_approx(a.data, f(a.data))
#down cast don,t work for now #down cast don,t work for now
# f = function([b],basic.convert_to_complex64(b)) #f = function([b],basic.convert_to_complex64(b))
# assert b.type.values_eq_approx(b.data, f(b.data)) #assert b.type.values_eq_approx(b.data, f(b.data))
for nbits in (64, 128): for nbits in (64, 128):
for t in ['int8','int16','int32','int64','float32','float64']: for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t)) a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype=t)) b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic.convert_to_complex128(a)) f = function([a],basic.convert_to_complex128(a))
self.failUnless(a.type.values_eq_approx(a.data, f(a.data))) assert a.type.values_eq_approx(b.data, f(a.data))
self.failUnless(b.type.values_eq_approx(b.data, f(b.data)))
x=Scalar(t)
y=basic.convert_to_complex128(v0)
f = function([x],y)
assert f(0)==0
assert f(1)==1
def test_bug_complext_10_august_09(): def test_bug_complext_10_august_09():
v0 = dmatrix() v0 = dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论