提交 a824fa00 authored 作者: James Bergstra's avatar James Bergstra

Separated casting functionality from Identity into a new Cast Op in

scalar/basic. There is also a cast function, which is used in the grad() of ops that might upcast their arguments, to downcast the corresponding gradients.
上级 a20abc22
...@@ -830,46 +830,60 @@ second = Second(transfer_type(1), name = 'second') ...@@ -830,46 +830,60 @@ second = Second(transfer_type(1), name = 'second')
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
def impl(self, x): def impl(self, input):
return getattr(numpy, self.output_types_preference.spec[0].dtype)(x) return input
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return [cast(gz, x.type.dtype)] return gz,
else: else:
return None, return None,
#backport
#return gz if x.type in grad_types else None,
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
#### CASTING OPERATIONS #### CASTING OPERATIONS
class Cast(UnaryScalarOp):
def __init__(self, o_type, name=None):
if not isinstance(o_type, Scalar):
raise TypeError(o_type)
super(Cast, self).__init__(specific_out(o_type), name=name)
self.o_type = o_type
self.ctor = getattr(numpy, o_type.dtype)
def impl(self, input):
return self.ctor(input)
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return [cast(gz, x.type.dtype)]
else:
return None,
convert_to_int8 = Identity(specific_out(int8), name='convert_to_int8') convert_to_int8 = Cast(int8, name='convert_to_int8')
convert_to_int16 = Identity(specific_out(int16), name='convert_to_int16') convert_to_int16 = Cast(int16, name='convert_to_int16')
convert_to_int32 = Identity(specific_out(int32), name='convert_to_int32') convert_to_int32 = Cast(int32, name='convert_to_int32')
convert_to_int64 = Identity(specific_out(int64), name='convert_to_int64') convert_to_int64 = Cast(int64, name='convert_to_int64')
convert_to_float32 = Identity(specific_out(float32), name='convert_to_float32') convert_to_float32 = Cast(float32, name='convert_to_float32')
convert_to_float64 = Identity(specific_out(float64), name='convert_to_float64') convert_to_float64 = Cast(float64, name='convert_to_float64')
convert_to_complex64 = Identity(specific_out(complex64), name='convert_to_complex64') convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128 = Identity(specific_out(complex128), name='convert_to_complex128') convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
def cast(t, dtype): _cast_mapping = {'int8': convert_to_int8,
if t.type.dtype == dtype: 'int16': convert_to_int16,
return t 'int32': convert_to_int32,
"""symbolically cast `t` to a Scalar of type `dtype`.""" 'int64': convert_to_int64,
mapping = {'int8': convert_to_int8, 'float32': convert_to_float32,
'int16': convert_to_int16, 'float64': convert_to_float64,
'int32': convert_to_int32, 'complex64': convert_to_complex64,
'int64': convert_to_int64, 'complex128': convert_to_complex128}
'float32': convert_to_float32, def cast(x, dtype):
'float64': convert_to_float64, """Symbolically cast `x` to a Scalar of given `dtype`."""
'complex64': convert_to_complex64, _x = as_scalar(x)
'complex128': convert_to_complex128} if _x.type.dtype == dtype:
if t.type.dtype.startswith('complex') and not dtype.startswith('complex'): return _x
if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()') raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()')
return mapping[dtype](t) return _cast_mapping[dtype](_x)
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
def make_node(self, x): def make_node(self, x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论