added care wrt non-int,non-float scalar types

上级 62c90bd2
...@@ -58,14 +58,15 @@ class Scalar(Result): ...@@ -58,14 +58,15 @@ class Scalar(Result):
def dtype_specs(self): def dtype_specs(self):
try: try:
return {'float32': (float, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'), return {'float32': (numpy.float32, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'float64': (float, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'), 'float64': (numpy.float64, 'npy_float64', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
'int8': (int, 'npy_int8', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'), 'complex128': (numpy.complex128, 'theano_complex128', 'PyComplex_Check', 'PyComplex_AsCComplex', 'PyComplex_FromCComplex'),
'int16': (int, 'npy_int16', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'), 'complex64': (numpy.complex64, 'theano_complex64', None, None, None),
'int32': (int, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'), 'int8': (numpy.int8, 'npy_int8', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'int64': (int, 'npy_int64', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'), 'int16': (numpy.int16, 'npy_int16', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'complex128': (complex, 'theano_complex128', 'PyComplex_Check', 'PyComplex_AsCComplex', 'PyComplex_FromCComplex'), 'int32': (numpy.int32, 'npy_int32', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong'),
'complex64': (complex, 'theano_complex64', None, None, None)}[self.dtype] 'int64': (numpy.int64, 'npy_int64', 'PyInt_Check', 'PyInt_AsLong', 'PyInt_FromLong')
}[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
...@@ -148,9 +149,7 @@ class Scalar(Result): ...@@ -148,9 +149,7 @@ class Scalar(Result):
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64) return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
def __copy__(self): def __copy__(self):
""" """Return a copy of this instance (with its own attributes)"""
Return a copy of this instance (with its own attributes)
"""
cpy = self.__class__(self.dtype, self.name) cpy = self.__class__(self.dtype, self.name)
cpy.data = self.data cpy.data = self.data
return cpy return cpy
...@@ -207,11 +206,16 @@ class ScalarOp(GuardedOp): ...@@ -207,11 +206,16 @@ class ScalarOp(GuardedOp):
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = [upcast(*i_dtypes)] * self.nout o_dtypes = self.output_dtypes(*i_dtypes)
self.inputs = inputs self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes] self.outputs = [Scalar(dtype) for dtype in o_dtypes]
def output_dtypes(self, *dtypes):
if self.nout != 1:
raise NotImplementedError()
return upcast(*dtypes),
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -232,6 +236,13 @@ class UnaryScalarOp(ScalarOp): ...@@ -232,6 +236,13 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
nin = 2 nin = 2
class FloatUnaryScalarOp(UnaryScalarOp):
def output_dtypes(self, input_dtype):
if 'int' in input_dtype: return 'float64',
if 'float' in input_dtype: return input_dtype,
raise NotImplementedError()
class Add(ScalarOp): class Add(ScalarOp):
identity = 0 identity = 0
...@@ -318,6 +329,7 @@ class Neg(UnaryScalarOp): ...@@ -318,6 +329,7 @@ class Neg(UnaryScalarOp):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
#TODO: for complex input, output is some flavour of float
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -333,22 +345,24 @@ class Abs(UnaryScalarOp): ...@@ -333,22 +345,24 @@ class Abs(UnaryScalarOp):
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.abs(x) / x #casting to output type is handled by filter
return 1.0 if x >= 0 else -1.0
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return None, return None,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s/%(prefix)sabs(%(x)s);" \ #casting is done by compiler
% dict(locals(), prefix = 'float' in self.inputs[0].dtype and 'f' or '') # TODO: C use copysign #TODO: use copysign
return "%(z)s = (%(x)s >= 0) ? 1.0 : -1.0;" % locals()
class Inv(UnaryScalarOp): class Inv(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1 / x return 1.0 / x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz / (x * x), return -gz / (x * x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" % locals() return "%(z)s = 1.0 / %(x)s;" % locals()
class Log(UnaryScalarOp): class Log(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -356,7 +370,7 @@ class Log(UnaryScalarOp): ...@@ -356,7 +370,7 @@ class Log(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" % locals() return "%(z)s = log(%(x)s);" % locals()
class Log2(UnaryScalarOp): class Log2(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -364,7 +378,7 @@ class Log2(UnaryScalarOp): ...@@ -364,7 +378,7 @@ class Log2(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals() return "%(z)s = log2(%(x)s);" % locals()
class Exp(UnaryScalarOp): class Exp(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -380,7 +394,7 @@ class Sqr(UnaryScalarOp): ...@@ -380,7 +394,7 @@ class Sqr(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
class Sqrt(UnaryScalarOp): class Sqrt(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sqrt(x) return math.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -388,7 +402,7 @@ class Sqrt(UnaryScalarOp): ...@@ -388,7 +402,7 @@ class Sqrt(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals() return "%(z)s = sqrt(%(x)s);" % locals()
class Cos(UnaryScalarOp): class Cos(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cos(x) return math.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -396,7 +410,7 @@ class Cos(UnaryScalarOp): ...@@ -396,7 +410,7 @@ class Cos(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals() return "%(z)s = cos(%(x)s);" % locals()
class Sin(UnaryScalarOp): class Sin(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return math.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -404,15 +418,15 @@ class Sin(UnaryScalarOp): ...@@ -404,15 +418,15 @@ class Sin(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
class Tan(UnaryScalarOp): class Tan(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tan(x) return math.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
raise NotImplementedError('lazy') raise NotImplementedError()
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals() return "%(z)s = tan(%(x)s);" % locals()
class Cosh(UnaryScalarOp): class Cosh(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cosh(x) return math.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -420,15 +434,15 @@ class Cosh(UnaryScalarOp): ...@@ -420,15 +434,15 @@ class Cosh(UnaryScalarOp):
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals() return "%(z)s = cosh(%(x)s);" % locals()
class Sinh(UnaryScalarOp): class Sinh(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return math.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz * cos(x), raise NotImplementedError()
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
class Tanh(UnaryScalarOp): class Tanh(FloatUnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tanh(x) return math.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -496,9 +510,6 @@ def composite(inputs, outputs): ...@@ -496,9 +510,6 @@ def composite(inputs, outputs):
i += 1 i += 1
name = "V%%(id)s_tmp%i" % i name = "V%%(id)s_tmp%i" % i
subd[output] = name subd[output] = name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code += "%s %s;\n" % (output.dtype_specs()[1], name) _c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
_c_code += op.c_code([subd[input] for input in op.inputs], _c_code += op.c_code([subd[input] for input in op.inputs],
[subd[output] for output in op.outputs], [subd[output] for output in op.outputs],
...@@ -529,7 +540,9 @@ def composite(inputs, outputs): ...@@ -529,7 +540,9 @@ def composite(inputs, outputs):
nin = len(inputs) nin = len(inputs)
nout = len(outputs) nout = len(outputs)
# todo: propagate_dtypes? def output_dtypes(self, *input_dtypes):
assert input_dtypes == tuple([input.dtype for input in inputs])
return [output.dtype for dtype in outputs]
def perform(self): def perform(self):
inputs = [input.data for input in self.inputs] inputs = [input.data for input in self.inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论