提交 ecd547fb authored 作者: --global's avatar --global

Document float16 issue with GpuElemwise

上级 a532d4e5
...@@ -274,6 +274,12 @@ def local_gpu_elemwise(node): ...@@ -274,6 +274,12 @@ def local_gpu_elemwise(node):
if isinstance(op.scalar_op, Pow): if isinstance(op.scalar_op, Pow):
old_out_dtype = node.outputs[0].dtype old_out_dtype = node.outputs[0].dtype
old_inp_dtypes = [inp.dtype for inp in node.inputs] old_inp_dtypes = [inp.dtype for inp in node.inputs]
# Upcast the input dtypes with 'float32' to obtain a floating-point
# dtype in which to do the computation.
# TODO : Currently, a bug in GpuElemwise prevents support for float16.
# It should be fixed and then the upcast below can use 'float16'
# instead of 'float32'
new_out_dtype = upcast("float32", *old_inp_dtypes) new_out_dtype = upcast("float32", *old_inp_dtypes)
# Transfer the inputs on the GPU and cast them to the right dtype # Transfer the inputs on the GPU and cast them to the right dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论