提交 a824fa00 authored 作者: James Bergstra's avatar James Bergstra

Separated casting functionality from Identity into a new Cast Op in

scalar/basic. There is also a cast function, which is used in the grad() of ops that might upcast their arguments, to downcast the corresponding gradients.
上级 a20abc22
...@@ -830,36 +830,45 @@ second = Second(transfer_type(1), name = 'second') ...@@ -830,36 +830,45 @@ second = Second(transfer_type(1), name = 'second')
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
def impl(self, x): def impl(self, input):
return getattr(numpy, self.output_types_preference.spec[0].dtype)(x) return input
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return [cast(gz, x.type.dtype)] return gz,
else: else:
return None, return None,
#backport
#return gz if x.type in grad_types else None,
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
#### CASTING OPERATIONS #### CASTING OPERATIONS
class Cast(UnaryScalarOp):
def __init__(self, o_type, name=None):
if not isinstance(o_type, Scalar):
raise TypeError(o_type)
super(Cast, self).__init__(specific_out(o_type), name=name)
self.o_type = o_type
self.ctor = getattr(numpy, o_type.dtype)
def impl(self, input):
return self.ctor(input)
def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return [cast(gz, x.type.dtype)]
else:
return None,
convert_to_int8 = Cast(int8, name='convert_to_int8')
convert_to_int16 = Cast(int16, name='convert_to_int16')
convert_to_int32 = Cast(int32, name='convert_to_int32')
convert_to_int64 = Cast(int64, name='convert_to_int64')
convert_to_float32 = Cast(float32, name='convert_to_float32')
convert_to_float64 = Cast(float64, name='convert_to_float64')
convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
convert_to_int8 = Identity(specific_out(int8), name='convert_to_int8') _cast_mapping = {'int8': convert_to_int8,
convert_to_int16 = Identity(specific_out(int16), name='convert_to_int16')
convert_to_int32 = Identity(specific_out(int32), name='convert_to_int32')
convert_to_int64 = Identity(specific_out(int64), name='convert_to_int64')
convert_to_float32 = Identity(specific_out(float32), name='convert_to_float32')
convert_to_float64 = Identity(specific_out(float64), name='convert_to_float64')
convert_to_complex64 = Identity(specific_out(complex64), name='convert_to_complex64')
convert_to_complex128 = Identity(specific_out(complex128), name='convert_to_complex128')
def cast(t, dtype):
if t.type.dtype == dtype:
return t
"""symbolically cast `t` to a Scalar of type `dtype`."""
mapping = {'int8': convert_to_int8,
'int16': convert_to_int16, 'int16': convert_to_int16,
'int32': convert_to_int32, 'int32': convert_to_int32,
'int64': convert_to_int64, 'int64': convert_to_int64,
...@@ -867,9 +876,14 @@ def cast(t, dtype): ...@@ -867,9 +876,14 @@ def cast(t, dtype):
'float64': convert_to_float64, 'float64': convert_to_float64,
'complex64': convert_to_complex64, 'complex64': convert_to_complex64,
'complex128': convert_to_complex128} 'complex128': convert_to_complex128}
if t.type.dtype.startswith('complex') and not dtype.startswith('complex'): def cast(x, dtype):
"""Symbolically cast `x` to a Scalar of given `dtype`."""
_x = as_scalar(x)
if _x.type.dtype == dtype:
return _x
if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()') raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()')
return mapping[dtype](t) return _cast_mapping[dtype](_x)
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
def make_node(self, x): def make_node(self, x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论