提交 9c91657a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix exceptions in scalar.

上级 a116149c
...@@ -39,7 +39,7 @@ builtin_int = int ...@@ -39,7 +39,7 @@ builtin_int = int
builtin_float = float builtin_float = float
class ComplexError(Exception): class ComplexError(NotImplementedError):
""" """
Raised if complex numbers are used in an unsupported operation. Raised if complex numbers are used in an unsupported operation.
...@@ -2197,7 +2197,7 @@ class Sgn(UnaryScalarOp): ...@@ -2197,7 +2197,7 @@ class Sgn(UnaryScalarOp):
return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals() return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
if type in int_types: if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals() return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() # complex has no sgn raise ComplexError('complex has no sgn')
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version() s = super(Sgn, self).c_code_cache_version()
...@@ -2299,8 +2299,8 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -2299,8 +2299,8 @@ class RoundHalfToEven(UnaryScalarOp):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
typ = node.outputs[0].type.dtype typ = node.outputs[0].type.dtype
if typ not in ['float32', 'float64']: if not typ.startswith('float'):
Exception("The output should be float32 or float64") raise NotImplementedError("The output should a float")
return dedent(""" return dedent("""
#ifndef ROUNDING_EPSILON #ifndef ROUNDING_EPSILON
...@@ -2398,7 +2398,7 @@ class RoundHalfAwayFromZero(UnaryScalarOp): ...@@ -2398,7 +2398,7 @@ class RoundHalfAwayFromZero(UnaryScalarOp):
if node.outputs[0].type.dtype in ['float32', 'float64']: if node.outputs[0].type.dtype in ['float32', 'float64']:
return "%(z)s = round(%(x)s);" % locals() return "%(z)s = round(%(x)s);" % locals()
else: else:
Exception("The output should be float32 or float64") raise NotImplementedError("The output should be a float")
round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only) round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论