提交 3f20f985 authored 作者: Frederic Bastien's avatar Frederic Bastien

Crash Fix some GpuElemwise with int*

上级 5dc75c1b
......@@ -2755,7 +2755,8 @@ class Log(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log((%(cast)s)%(x)s);" % locals()
log = Log(upgrade_to_float, name='log')
......@@ -2794,7 +2795,8 @@ class Log2(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log2(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log2((%(cast)s)%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name='log2')
......@@ -2833,7 +2835,8 @@ class Log10(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log10(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log10((%(cast)s)%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name='log10')
......@@ -2870,7 +2873,8 @@ class Log1p(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = log1p(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = log1p((%(cast)s)%(x)s);" % locals()
log1p = Log1p(upgrade_to_float, name='log1p')
......@@ -2905,7 +2909,8 @@ class Exp(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = exp(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = exp((%(cast)s)%(x)s);" % locals()
exp = Exp(upgrade_to_float, name='exp')
......@@ -2938,7 +2943,8 @@ class Exp2(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = exp2(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = exp2((%(cast)s)%(x)s);" % locals()
exp2 = Exp2(upgrade_to_float, name='exp2')
......@@ -2971,7 +2977,8 @@ class Expm1(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = expm1(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = expm1((%(cast)s)%(x)s);" % locals()
def c_code_cache_version(self):
return (5,)
......@@ -3033,7 +3040,8 @@ class Sqrt(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sqrt(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = sqrt((%(cast)s)%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name='sqrt')
......@@ -3134,7 +3142,8 @@ class Cos(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = cos(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = cos((%(cast)s)%(x)s);" % locals()
cos = Cos(upgrade_to_float, name='cos')
......@@ -3167,7 +3176,8 @@ class ArcCos(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = acos(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = acos((%(cast)s)%(x)s);" % locals()
arccos = ArcCos(upgrade_to_float, name='arccos')
......@@ -3202,7 +3212,8 @@ class Sin(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sin(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = sin((%(cast)s)%(x)s);" % locals()
sin = Sin(upgrade_to_float, name='sin')
......@@ -3235,7 +3246,8 @@ class ArcSin(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = asin(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = asin((%(cast)s)%(x)s);" % locals()
arcsin = ArcSin(upgrade_to_float, name='arcsin')
......@@ -3268,7 +3280,8 @@ class Tan(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = tan(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = tan((%(cast)s)%(x)s);" % locals()
tan = Tan(upgrade_to_float, name='tan')
......@@ -3301,7 +3314,8 @@ class ArcTan(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atan(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = atan((%(cast)s)%(x)s);" % locals()
arctan = ArcTan(upgrade_to_float, name='arctan')
......@@ -3383,7 +3397,8 @@ class Cosh(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = cosh(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = cosh((%(cast)s)%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name='cosh')
......@@ -3416,7 +3431,8 @@ class ArcCosh(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = acosh(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = acosh((%(cast)s)%(x)s);" % locals()
arccosh = ArcCosh(upgrade_to_float, name='arccosh')
......@@ -3453,7 +3469,8 @@ class Sinh(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = sinh(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = sinh((%(cast)s)%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name='sinh')
......@@ -3486,7 +3503,8 @@ class ArcSinh(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = asinh(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = asinh((%(cast)s)%(x)s);" % locals()
arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')
......@@ -3524,7 +3542,8 @@ class Tanh(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = tanh(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = tanh((%(cast)s)%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name='tanh')
......@@ -3557,7 +3576,8 @@ class ArcTanh(UnaryScalarOp):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = atanh(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = atanh((%(cast)s)%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论