提交 af43180b authored 作者: Frederic Bastien's avatar Frederic Bastien

fixed the fix about complex type and added test for it. Downcast not allow for now.

上级 7ce5a2af
...@@ -165,22 +165,25 @@ class Scalar(Type): ...@@ -165,22 +165,25 @@ class Scalar(Type):
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret; return ret;
} }
complex_type operator =(scalar_type y) { complex_type& operator =(const scalar_type& y) {
complex_type ret; this->real=y;
ret.real=y; this->imag=0;
ret.imag=0; return *this;
return ret;
} }
%(upcast)s %(upcast)s
}; };
""" """
# todo: use C templating # todo: use C templating
return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast=""" return template % dict(nbits = 64, half_nbits = 32, upcast="") + template % dict(nbits = 128, half_nbits = 64, upcast="""
complex_type operator =(npy_float32 y) { complex_type& operator =(npy_float32 y) {
complex_type ret; this->real=y;
ret.real=y; this->imag=0;
ret.imag=0; return *this;
return ret; }
complex_type& operator =(theano_complex64 y) {
this->real=y.real;
this->imag=y.imag;
return *this;
} }
""") """)
......
...@@ -1920,20 +1920,47 @@ def test_sum_overflow(): ...@@ -1920,20 +1920,47 @@ def test_sum_overflow():
f = function([a], sum(a)) f = function([a], sum(a))
assert f([1]*300) == 300 assert f([1]*300) == 300
def test_convert_to_complex():
a = value(numpy.ones(3, dtype='complex64'))#+0.5j)
b = value(numpy.ones(3, dtype='complex128'))#+0.5j)
f = function([a],basic.convert_to_complex128(a))
#we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data))
f = function([b],basic.convert_to_complex128(b))
assert b.type.values_eq_approx(b.data, f(b.data))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(a.data, f(a.data))
#down cast don,t work for now
# f = function([b],basic.convert_to_complex64(b))
# assert b.type.values_eq_approx(b.data, f(b.data))
for nbits in (64, 128):
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype=t))
f = function([a],basic.convert_to_complex128(a))
self.failUnless(a.type.values_eq_approx(a.data, f(a.data)))
self.failUnless(b.type.values_eq_approx(b.data, f(b.data)))
x=Scalar(t)
y=basic.convert_to_complex128(v0)
f = function([x],y)
assert f(0)==0
assert f(1)==1
def test_bug_complext_10_august_09(): def test_bug_complext_10_august_09():
v0 = dmatrix() v0 = dmatrix()
v1 = dscalar() v1 = basic.convert_to_complex128(v0)
v2 = dvector()
v3 = dscalar()
v5 = sub(v3,v3)
v6 = basic.convert_to_complex128(v0)
# v7 = basic.convert_to_float64(v6)
inputs = [v0, v1, v2, v3]
outputs = [v5,v6]
function(inputs, outputs, mode=compile.Mode('py', 'fast_compile'))
function(inputs, outputs, mode=compile.debugmode.DebugMode())
inputs = [v0]
outputs = [v1]
f = function(inputs, outputs)
i = numpy.zeros((2,2))
assert (f(i)==numpy.zeros((2,2))).all()
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT': if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论