提交 2c1d697c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add float16 support for GpuEye.

上级 e6dc5e8d
......@@ -922,13 +922,24 @@ class GpuEye(GpuKernelBase, Op):
return hash(self.dtype) ^ hash(type(self))
def gpu_kernels(self, node, name):
code = """
if self.dtype == 'float16':
code = """
KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ga_size nb = n < m ? n : m;
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = __float2half_rn(1);
}
}"""
else:
code = """
KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ga_size nb = n < m ? n : m;
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = 1;
}
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype), name=name)
}"""
code = code % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
name=name)
return [Kernel(
code=code, name="k",
params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE],
......
......@@ -436,7 +436,7 @@ def test_gpueye():
assert any([isinstance(node.op, GpuEye)
for node in f.maker.fgraph.toposort()])
for dtype in ['float32', 'int32']:
for dtype in ['float32', 'int32', 'float16']:
yield check, dtype, 3
# M != N, k = 0
yield check, dtype, 3, 5
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论