added dtypes to scalar

上级 250d45dd
......@@ -450,7 +450,6 @@ class T_mul(unittest.TestCase):
b = astensor(0.0)
check_eq2_both(self, [a,b], mul(a,b), [3.0, 4.0], 12.0)
check_eq2_both(self, [a,b], mul(b,a), [-1.0,2.0], -2.0)
#self.failUnless(isinstance(mul(a,b).owner, Scale))
a = astensor(numpy.ones(2))
b = astensor(numpy.ones(2))
......@@ -458,7 +457,6 @@ class T_mul(unittest.TestCase):
bb = numpy.asarray([-0.5, 2.0])
check_eq2_both(self, [a,b], mul(a,b), [aa,bb], numpy.asarray([0.25, 8.0]))
check_eq2_both(self, [a,b], mul(a,b), [bb,aa], numpy.asarray([0.25, 8.0]))
#self.failUnless(isinstance(mul(a,b).owner, MulElemwise))
def test_scalar(self):
r = numpy.random.rand(2,3)
......@@ -490,16 +488,6 @@ class T_mul(unittest.TestCase):
def test_grad_col(self):
verify_grad(self, Mul, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
# def test_operator(self):
# a = astensor([1,1])
# aa = astensor([1,1])
# b = astensor(4)
# self.failUnless(isinstance((a*b).owner, Scale))
# self.failUnless(isinstance((b*a).owner, Scale))
# self.failUnless(isinstance((a*aa).owner, MulElemwise))
# self.failUnless(isinstance((aa*a).owner, MulElemwise))
def test_wrong_shapes(self):
a = astensor(numpy.ones(3))
b = astensor(numpy.ones(4))
......@@ -734,7 +722,6 @@ class t_gemm(unittest.TestCase):
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
t_gemm = None
if __name__ == '__main__':
unittest.main()
......@@ -25,6 +25,7 @@ class Scalar(ResultBase):
def __init__(self, dtype, name = None):
ResultBase.__init__(self, role = None, name = name)
self.dtype = dtype
self.dtype_specs()
def __get_constant(self):
return self._constant
......@@ -49,8 +50,17 @@ class Scalar(ResultBase):
# and self.data == other.data
def dtype_specs(self):
return {'float64': (float, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'int32': (int, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong')}[self.dtype]
try:
return {'float32': (float, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'float64': (float, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'int8': (int, 'npy_int8', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int16': (int, 'npy_int16', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int32': (int, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int64': (int, 'npy_int64', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'complex128': (complex, 'theano_complex128', 'PyComplex_Check', 'PyComplex_AsCComplex', 'PyComplex_FromCComplex'),
'complex64': (complex, 'theano_complex64', None, None, None)}[self.dtype]
except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
def c_declare(self, name, sub):
return """
......@@ -89,6 +99,42 @@ class Scalar(ResultBase):
def c_cleanup(self, name, sub):
return ""
def c_support_code(cls):
template = """
struct theano_complex%(nbits)s : public npy_complex%(nbits)s
{
typedef theano_complex%(nbits)s complex_type;
typedef npy_float%(half_nbits)s scalar_type;
complex_type operator +(complex_type y) {
complex_type ret;
ret.real = this->real + y.real;
ret.imag = this->imag + y.imag;
return ret;
}
complex_type operator -(complex_type y) {
complex_type ret;
ret.real = this->real - y.real;
ret.imag = this->imag - y.imag;
return ret;
}
complex_type operator *(complex_type y) {
complex_type ret;
ret.real = this->real * y.real - this->imag * y.imag;
ret.imag = this->real * y.imag + this->imag * y.real;
return ret;
}
complex_type operator /(complex_type y) {
complex_type ret;
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
return ret;
}
};
"""
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
def __copy__(self):
"""
Return a copy of this instance (with its own attributes)
......@@ -128,15 +174,6 @@ class ScalarMixedOp(GuardedOp):
def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
# def c_var_names(self):
# (self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
# inames = utils.from_return_values(inames)
# onames = utils.from_return_values(onames)
# return [inames, onames]
# def c_code(self):
# return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论