提交 6ab3dddc authored 作者: Frederic's avatar Frederic

Add extra check for safety.

上级 0ff016bc
......@@ -1662,7 +1662,8 @@ class Pow(BinaryScalarOp):
if (node.inputs[0].type == node.outputs[0].type and
node.inputs[1].type == node.outputs[0].type and
# amdlibm 3.0 do not have a float64 version of this SIMD function
node.inputs[0].dtype == 'float32'):
node.inputs[0].dtype == 'float32' and
node.inputs[1].dtype == 'float32'):
dtype = 'float'
fct = "amd_vrsa_powf"
return """
......@@ -1674,10 +1675,11 @@ class Pow(BinaryScalarOp):
""" % locals()
# We compare the dtype and check we broadcast a scalar
elif (node.inputs[0].type == node.outputs[0].type and
node.inputs[1].dtype == node.outputs[0].dtype and
all(node.inputs[1].broadcastable) and
# amdlibm 3.0 do not have a float64 version of this SIMD function
node.inputs[0].dtype == 'float32'):
node.inputs[1].dtype == node.outputs[0].dtype and
all(node.inputs[1].broadcastable) and
# amdlibm 3.0 do not have a float64 version of this SIMD function
node.inputs[0].dtype == 'float32' and
node.inputs[1].dtype == 'float32'):
dtype = 'float'
fct = "amd_vrsa_powxf"
return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论