提交 3f20f985 authored 作者: Frederic Bastien's avatar Frederic Bastien

Crash Fix some GpuElemwise with int*

上级 5dc75c1b
...@@ -2755,7 +2755,8 @@ class Log(UnaryScalarOp): ...@@ -2755,7 +2755,8 @@ class Log(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = log(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log((%(cast)s)%(x)s);" % locals()
log = Log(upgrade_to_float, name='log') log = Log(upgrade_to_float, name='log')
...@@ -2794,7 +2795,8 @@ class Log2(UnaryScalarOp): ...@@ -2794,7 +2795,8 @@ class Log2(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = log2(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log2((%(cast)s)%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name='log2') log2 = Log2(upgrade_to_float, name='log2')
...@@ -2833,7 +2835,8 @@ class Log10(UnaryScalarOp): ...@@ -2833,7 +2835,8 @@ class Log10(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = log10(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log10((%(cast)s)%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name='log10') log10 = Log10(upgrade_to_float, name='log10')
...@@ -2870,7 +2873,8 @@ class Log1p(UnaryScalarOp): ...@@ -2870,7 +2873,8 @@ class Log1p(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = log1p(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log1p((%(cast)s)%(x)s);" % locals()
log1p = Log1p(upgrade_to_float, name='log1p') log1p = Log1p(upgrade_to_float, name='log1p')
...@@ -2905,7 +2909,8 @@ class Exp(UnaryScalarOp): ...@@ -2905,7 +2909,8 @@ class Exp(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = exp(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = exp((%(cast)s)%(x)s);" % locals()
exp = Exp(upgrade_to_float, name='exp') exp = Exp(upgrade_to_float, name='exp')
...@@ -2938,7 +2943,8 @@ class Exp2(UnaryScalarOp): ...@@ -2938,7 +2943,8 @@ class Exp2(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = exp2(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = exp2((%(cast)s)%(x)s);" % locals()
exp2 = Exp2(upgrade_to_float, name='exp2') exp2 = Exp2(upgrade_to_float, name='exp2')
...@@ -2971,7 +2977,8 @@ class Expm1(UnaryScalarOp): ...@@ -2971,7 +2977,8 @@ class Expm1(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = expm1(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = expm1((%(cast)s)%(x)s);" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (5,)
...@@ -3033,7 +3040,8 @@ class Sqrt(UnaryScalarOp): ...@@ -3033,7 +3040,8 @@ class Sqrt(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = sqrt(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = sqrt((%(cast)s)%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name='sqrt') sqrt = Sqrt(upgrade_to_float, name='sqrt')
...@@ -3134,7 +3142,8 @@ class Cos(UnaryScalarOp): ...@@ -3134,7 +3142,8 @@ class Cos(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = cos(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = cos((%(cast)s)%(x)s);" % locals()
cos = Cos(upgrade_to_float, name='cos') cos = Cos(upgrade_to_float, name='cos')
...@@ -3167,7 +3176,8 @@ class ArcCos(UnaryScalarOp): ...@@ -3167,7 +3176,8 @@ class ArcCos(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = acos(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = acos((%(cast)s)%(x)s);" % locals()
arccos = ArcCos(upgrade_to_float, name='arccos') arccos = ArcCos(upgrade_to_float, name='arccos')
...@@ -3202,7 +3212,8 @@ class Sin(UnaryScalarOp): ...@@ -3202,7 +3212,8 @@ class Sin(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = sin(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = sin((%(cast)s)%(x)s);" % locals()
sin = Sin(upgrade_to_float, name='sin') sin = Sin(upgrade_to_float, name='sin')
...@@ -3235,7 +3246,8 @@ class ArcSin(UnaryScalarOp): ...@@ -3235,7 +3246,8 @@ class ArcSin(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = asin(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = asin((%(cast)s)%(x)s);" % locals()
arcsin = ArcSin(upgrade_to_float, name='arcsin') arcsin = ArcSin(upgrade_to_float, name='arcsin')
...@@ -3268,7 +3280,8 @@ class Tan(UnaryScalarOp): ...@@ -3268,7 +3280,8 @@ class Tan(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = tan(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = tan((%(cast)s)%(x)s);" % locals()
tan = Tan(upgrade_to_float, name='tan') tan = Tan(upgrade_to_float, name='tan')
...@@ -3301,7 +3314,8 @@ class ArcTan(UnaryScalarOp): ...@@ -3301,7 +3314,8 @@ class ArcTan(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = atan(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = atan((%(cast)s)%(x)s);" % locals()
arctan = ArcTan(upgrade_to_float, name='arctan') arctan = ArcTan(upgrade_to_float, name='arctan')
...@@ -3383,7 +3397,8 @@ class Cosh(UnaryScalarOp): ...@@ -3383,7 +3397,8 @@ class Cosh(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = cosh(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = cosh((%(cast)s)%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name='cosh') cosh = Cosh(upgrade_to_float, name='cosh')
...@@ -3416,7 +3431,8 @@ class ArcCosh(UnaryScalarOp): ...@@ -3416,7 +3431,8 @@ class ArcCosh(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = acosh(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = acosh((%(cast)s)%(x)s);" % locals()
arccosh = ArcCosh(upgrade_to_float, name='arccosh') arccosh = ArcCosh(upgrade_to_float, name='arccosh')
...@@ -3453,7 +3469,8 @@ class Sinh(UnaryScalarOp): ...@@ -3453,7 +3469,8 @@ class Sinh(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = sinh(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = sinh((%(cast)s)%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name='sinh') sinh = Sinh(upgrade_to_float, name='sinh')
...@@ -3486,7 +3503,8 @@ class ArcSinh(UnaryScalarOp): ...@@ -3486,7 +3503,8 @@ class ArcSinh(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = asinh(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = asinh((%(cast)s)%(x)s);" % locals()
arcsinh = ArcSinh(upgrade_to_float, name='arcsinh') arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')
...@@ -3524,7 +3542,8 @@ class Tanh(UnaryScalarOp): ...@@ -3524,7 +3542,8 @@ class Tanh(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = tanh(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = tanh((%(cast)s)%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name='tanh') tanh = Tanh(upgrade_to_float, name='tanh')
...@@ -3557,7 +3576,8 @@ class ArcTanh(UnaryScalarOp): ...@@ -3557,7 +3576,8 @@ class ArcTanh(UnaryScalarOp):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = atanh(%(x)s);" % locals() cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = atanh((%(cast)s)%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh') arctanh = ArcTanh(upgrade_to_float, name='arctanh')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论