added care wrt non-int,non-float scalar types

上级 62c90bd2
......@@ -58,14 +58,15 @@ class Scalar(Result):
def dtype_specs(self):
try:
return {'float32': (float, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'float64': (float, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'int8': (int, 'npy_int8', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int16': (int, 'npy_int16', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int32': (int, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int64': (int, 'npy_int64', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'complex128': (complex, 'theano_complex128', 'PyComplex_Check', 'PyComplex_AsCComplex', 'PyComplex_FromCComplex'),
'complex64': (complex, 'theano_complex64', None, None, None)}[self.dtype]
return {'float32': (numpy.float32, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'float64': (numpy.float64, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'complex128': (numpy.complex128, 'theano_complex128', 'PyComplex_Check', 'PyComplex_AsCComplex', 'PyComplex_FromCComplex'),
'complex64': (numpy.complex64, 'theano_complex64', None, None, None),
'int8': (numpy.int8, 'npy_int8', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int16': (numpy.int16, 'npy_int16', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int32': (numpy.int32, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int64': (numpy.int64, 'npy_int64', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong')
}[self.dtype]
except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
......@@ -148,9 +149,7 @@ class Scalar(Result):
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
def __copy__(self):
"""
Return a copy of this instance (with its own attributes)
"""
"""Return a copy of this instance (with its own attributes)"""
cpy = self.__class__(self.dtype, self.name)
cpy.data = self.data
return cpy
......@@ -207,11 +206,16 @@ class ScalarOp(GuardedOp):
inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = [upcast(*i_dtypes)] * self.nout
o_dtypes = self.output_dtypes(*i_dtypes)
self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes]
def output_dtypes(self, *dtypes):
if self.nout != 1:
raise NotImplementedError()
return upcast(*dtypes),
def impl(self, *inputs):
raise AbstractFunctionError()
......@@ -232,6 +236,13 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp):
nin = 2
class FloatUnaryScalarOp(UnaryScalarOp):
def output_dtypes(self, input_dtype):
if 'int' in input_dtype: return 'float64',
if 'float' in input_dtype: return input_dtype,
raise NotImplementedError()
class Add(ScalarOp):
identity = 0
......@@ -318,6 +329,7 @@ class Neg(UnaryScalarOp):
return "%(z)s = -%(x)s;" % locals()
class Abs(UnaryScalarOp):
#TODO: for complex input, output is some flavour of float
def impl(self, x):
return numpy.abs(x)
def grad(self, (x, ), (gz, )):
......@@ -333,22 +345,24 @@ class Abs(UnaryScalarOp):
class Sgn(UnaryScalarOp):
def impl(self, x):
return numpy.abs(x) / x
#casting to output type is handled by filter
return 1.0 if x >= 0 else -1.0
def grad(self, (x, ), (gz, )):
return None,
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s/%(prefix)sabs(%(x)s);" \
% dict(locals(), prefix = 'float' in self.inputs[0].dtype and 'f' or '') # TODO: C use copysign
#casting is done by compiler
#TODO: use copysign
return "%(z)s = (%(x)s >= 0) ? 1.0 : -1.0;" % locals()
class Inv(UnaryScalarOp):
class Inv(FloatUnaryScalarOp):
def impl(self, x):
return 1 / x
return 1.0 / x
def grad(self, (x, ), (gz, )):
return -gz / (x * x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" % locals()
return "%(z)s = 1.0 / %(x)s;" % locals()
class Log(UnaryScalarOp):
class Log(FloatUnaryScalarOp):
def impl(self, x):
return math.log(x)
def grad(self, (x, ), (gz, )):
......@@ -356,7 +370,7 @@ class Log(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" % locals()
class Log2(UnaryScalarOp):
class Log2(FloatUnaryScalarOp):
def impl(self, x):
return numpy.log2(x)
def grad(self, (x, ), (gz, )):
......@@ -364,7 +378,7 @@ class Log2(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals()
class Exp(UnaryScalarOp):
class Exp(FloatUnaryScalarOp):
def impl(self, x):
return math.exp(x)
def grad(self, (x, ), (gz, )):
......@@ -380,7 +394,7 @@ class Sqr(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals()
class Sqrt(UnaryScalarOp):
class Sqrt(FloatUnaryScalarOp):
def impl(self, x):
return math.sqrt(x)
def grad(self, (x, ), (gz, )):
......@@ -388,7 +402,7 @@ class Sqrt(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals()
class Cos(UnaryScalarOp):
class Cos(FloatUnaryScalarOp):
def impl(self, x):
return math.cos(x)
def grad(self, (x, ), (gz, )):
......@@ -396,7 +410,7 @@ class Cos(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals()
class Sin(UnaryScalarOp):
class Sin(FloatUnaryScalarOp):
def impl(self, x):
return math.sin(x)
def grad(self, (x, ), (gz, )):
......@@ -404,15 +418,15 @@ class Sin(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals()
class Tan(UnaryScalarOp):
class Tan(FloatUnaryScalarOp):
def impl(self, x):
return math.tan(x)
def grad(self, (x, ), (gz, )):
raise NotImplementedError('lazy')
raise NotImplementedError()
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals()
class Cosh(UnaryScalarOp):
class Cosh(FloatUnaryScalarOp):
def impl(self, x):
return math.cosh(x)
def grad(self, (x, ), (gz, )):
......@@ -420,15 +434,15 @@ class Cosh(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals()
class Sinh(UnaryScalarOp):
class Sinh(FloatUnaryScalarOp):
def impl(self, x):
return math.sin(x)
return math.sinh(x)
def grad(self, (x, ), (gz, )):
return -gz * cos(x),
raise NotImplementedError()
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals()
class Tanh(UnaryScalarOp):
class Tanh(FloatUnaryScalarOp):
def impl(self, x):
return math.tanh(x)
def grad(self, (x, ), (gz, )):
......@@ -496,9 +510,6 @@ def composite(inputs, outputs):
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
_c_code += op.c_code([subd[input] for input in op.inputs],
[subd[output] for output in op.outputs],
......@@ -529,8 +540,10 @@ def composite(inputs, outputs):
nin = len(inputs)
nout = len(outputs)
# todo: propagate_dtypes?
def output_dtypes(self, *input_dtypes):
assert input_dtypes == tuple([input.dtype for input in inputs])
return [output.dtype for dtype in outputs]
def perform(self):
inputs = [input.data for input in self.inputs]
for output, impl in zip(self.outputs, _impls):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论