提交 6ab3dddc authored 作者: Frederic's avatar Frederic

Add extra check for safety.

上级 0ff016bc
...@@ -1662,7 +1662,8 @@ class Pow(BinaryScalarOp): ...@@ -1662,7 +1662,8 @@ class Pow(BinaryScalarOp):
if (node.inputs[0].type == node.outputs[0].type and if (node.inputs[0].type == node.outputs[0].type and
node.inputs[1].type == node.outputs[0].type and node.inputs[1].type == node.outputs[0].type and
# amdlibm 3.0 do not have a float64 version of this SIMD function # amdlibm 3.0 do not have a float64 version of this SIMD function
node.inputs[0].dtype == 'float32'): node.inputs[0].dtype == 'float32' and
node.inputs[1].dtype == 'float32'):
dtype = 'float' dtype = 'float'
fct = "amd_vrsa_powf" fct = "amd_vrsa_powf"
return """ return """
...@@ -1677,7 +1678,8 @@ class Pow(BinaryScalarOp): ...@@ -1677,7 +1678,8 @@ class Pow(BinaryScalarOp):
node.inputs[1].dtype == node.outputs[0].dtype and node.inputs[1].dtype == node.outputs[0].dtype and
all(node.inputs[1].broadcastable) and all(node.inputs[1].broadcastable) and
# amdlibm 3.0 do not have a float64 version of this SIMD function # amdlibm 3.0 do not have a float64 version of this SIMD function
node.inputs[0].dtype == 'float32'): node.inputs[0].dtype == 'float32' and
node.inputs[1].dtype == 'float32'):
dtype = 'float' dtype = 'float'
fct = "amd_vrsa_powxf" fct = "amd_vrsa_powxf"
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论