Commit 59553e0a authored by Arnaud Bergeron

Add a blurb about how to deal with float16.

Parent 0faeaab5
......@@ -156,25 +156,48 @@ in most cases in your :meth:`c_code` method. This is done by using
the provided wrapper function. An example calling the above kernel
would be::
size_t ls, gs;
size_t dims[2];
size_t n = 256;
// ...
ls = 1;
gs = 256;
err = k_call(1, &gs, &ls, 0, input->ga.data, dims[0], dims[1]);
err = k_scall(1, &n, 0, input->ga.data, dims[0], dims[1]);
// ...
If you want explicit control over the scheduling, you can use the
`_call` wrapper instead which works like this::
size_t ls, gs;
// ...
gs = 1;
ls = 256;
err = k_call(1, &gs, &ls, 0, input->ga.data, dims[0], dims[1]);
The name of the wrapper function depends on the name you passed to
``Kernel()`` when you declared it (or the name in your `#kernel`
statement). It defaults to `name + '_call'`.
statement). It defaults to `name + '_call'` or `name + '_scall'`.
For other operations in the C code you should refer to the
`libgpuarray documentation
<http://deeplearning.net/software/libgpuarray/>`_.
Dealing with float16
====================
To support limited-precision storage in a kernel, you have to be
careful to load values properly, declare working memory in float32, and
write results properly. To help with that, some functions are
provided in `theano.gpuarray.fp16_help`.
To load the inputs, wrap each read in the function returned
by :func:`load_w`. Similarly, wrap each write in the
function returned by :func:`write_w`. Finally, working data should
have the type returned by :func:`work_dtype`.
A Complete Example
==================
......
......@@ -2,6 +2,10 @@ from __future__ import absolute_import, print_function, division
def work_dtype(dtype):
"""
Return the data type for working memory.
"""
if dtype == 'float16':
return 'float32'
else:
......@@ -9,6 +13,14 @@ def work_dtype(dtype):
def load_w(dtype):
"""
Return the function name to load data.
This should be used like this::
code = '%(load_f)s(ival)' % {'load_f': load_w(input_type)}
"""
if dtype == 'float16':
return '__half2float'
else:
......@@ -16,6 +28,14 @@ def load_w(dtype):
def write_w(dtype):
"""
Return the function name to write data.
This should be used like this::
code = 'res = %(write_f)s(oval)' % {'write_f': write_w(output_type)}
"""
if dtype == 'float16':
return '__float2half_rn'
else:
......
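As a rough illustration of how the three helpers could fit together when generating kernel source, here is a minimal sketch. It does not import `theano.gpuarray.fp16_help`; instead it defines local stand-ins whose `float16` branches mirror the diff above. The `else` branches are elided in the diff, so their behavior here (returning the dtype unchanged, or an empty wrapper name) is an assumption. The `C_TYPES` table, the `scale_kernel_body` function, and the `in`/`out`/`i` names in the generated C are hypothetical and exist only for this example.

```python
# Local stand-ins for the helpers shown in the diff. The float16 branches
# follow the diff; the non-float16 branches are ASSUMPTIONS, since the
# diff elides them.

def work_dtype(dtype):
    """Return the dtype to use for working memory."""
    if dtype == 'float16':
        return 'float32'
    return dtype  # assumption: other dtypes are used as-is


def load_w(dtype):
    """Return the name of the function that wraps a read."""
    if dtype == 'float16':
        return '__half2float'
    return ''  # assumption: no conversion, so the wrapper is just parentheses


def write_w(dtype):
    """Return the name of the function that wraps a write."""
    if dtype == 'float16':
        return '__float2half_rn'
    return ''  # assumption: no conversion, so the wrapper is just parentheses


# Hypothetical mapping from dtype names to C type names, for this sketch only.
C_TYPES = {'float16': 'half', 'float32': 'float', 'float64': 'double'}


def scale_kernel_body(in_dtype, out_dtype):
    """Build the body of a hypothetical kernel that doubles each element.

    Reads are wrapped with load_w, the temporary uses work_dtype, and
    writes are wrapped with write_w.
    """
    template = (
        "%(work_t)s tmp = %(load_f)s(in[i]);\n"
        "tmp = tmp * 2;\n"
        "out[i] = %(write_f)s(tmp);\n"
    )
    # Named placeholders need a mapping on the right-hand side of %.
    return template % {
        'work_t': C_TYPES[work_dtype(in_dtype)],
        'load_f': load_w(in_dtype),
        'write_f': write_w(out_dtype),
    }
```

With `float16` inputs and outputs, the generated body computes in `float` and converts only at the load and store; with `float32`, the wrappers reduce to plain parentheses, so the same template serves both cases.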