提交 5d44e7db authored 作者: James Bergstra's avatar James Bergstra

Modified scalar.Identity to provide casting functionality

上级 262fcf78
...@@ -805,12 +805,12 @@ second = Second(transfer_type(1), name = 'second') ...@@ -805,12 +805,12 @@ second = Second(transfer_type(1), name = 'second')
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x return getattr(numpy, self.output_types_preference.spec[0].dtype)(x)
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz, return [cast(gz, x.type.dtype)]
else: else:
return None, return None,
...@@ -841,6 +841,8 @@ def cast(t, dtype): ...@@ -841,6 +841,8 @@ def cast(t, dtype):
'float64': convert_to_float64, 'float64': convert_to_float64,
'complex64': convert_to_complex64, 'complex64': convert_to_complex64,
'complex128': convert_to_complex128} 'complex128': convert_to_complex128}
if t.type.dtype.startswith('complex') and not dtype.startswith('complex'):
raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()')
return mapping[dtype](t) return mapping[dtype](t)
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论