提交 b08d5cb7 authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Remove useless gpu opt warning when dtype isn't float32

上级 ce0a3f02
......@@ -1116,6 +1116,8 @@ def local_gpu_incsubtensor(node):
incsubt = host_output.owner.op
x, y = host_output.owner.inputs[0:2]
coords = host_output.owner.inputs[2:]
if x.dtype != "float32" or y.dtype != "float32":
return
return [GpuIncSubtensor(
incsubt.idx_list,
inplace=incsubt.inplace,
......@@ -1126,7 +1128,7 @@ def local_gpu_incsubtensor(node):
# Incrementing a float32 x results in a float32
# output even if y is float64, so we can downcast
# y to put it on GPU
if type(node.op) == tensor.IncSubtensor and \
elif type(node.op) == tensor.IncSubtensor and \
node.inputs[0].dtype == "float32":
x, y = node.inputs[0:2]
assert isinstance(x.type, tensor.TensorType)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论