提交 a6ea5b1e authored 作者: Frederic Bastien's avatar Frederic Bastien

more fix to complex type cast and better test.

上级 af43180b
......@@ -175,11 +175,6 @@ class Scalar(Type):
"""
# todo: use C templating
return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast="""
complex_type& operator =(npy_float32 y) {
this->real=y;
this->imag=0;
return *this;
}
complex_type& operator =(theano_complex64 y) {
this->real=y.real;
this->imag=y.imag;
......
......@@ -1921,8 +1921,8 @@ def test_sum_overflow():
assert f([1]*300) == 300
def test_convert_to_complex():
a = value(numpy.ones(3, dtype='complex64'))#+0.5j)
b = value(numpy.ones(3, dtype='complex128'))#+0.5j)
a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128')+0.5j)
f = function([a],basic.convert_to_complex128(a))
#we need to compare with the same type.
......@@ -1935,23 +1935,16 @@ def test_convert_to_complex():
assert a.type.values_eq_approx(a.data, f(a.data))
#down cast don,t work for now
# f = function([b],basic.convert_to_complex64(b))
# assert b.type.values_eq_approx(b.data, f(b.data))
#f = function([b],basic.convert_to_complex64(b))
#assert b.type.values_eq_approx(b.data, f(b.data))
for nbits in (64, 128):
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic.convert_to_complex128(a))
self.failUnless(a.type.values_eq_approx(a.data, f(a.data)))
self.failUnless(b.type.values_eq_approx(b.data, f(b.data)))
assert a.type.values_eq_approx(b.data, f(a.data))
x=Scalar(t)
y=basic.convert_to_complex128(v0)
f = function([x],y)
assert f(0)==0
assert f(1)==1
def test_bug_complext_10_august_09():
v0 = dmatrix()
v1 = basic.convert_to_complex128(v0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论