提交 c3f03aab authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't duplicate the kernel code in GpuEye for float16.

上级 ed222a6d
......@@ -20,6 +20,7 @@ except ImportError:
pass
from .type import GpuArrayType
from .fp16_help import write_w
def as_gpuarray_variable(x):
......@@ -888,6 +889,8 @@ class GpuSplit(HideC, Split):
class GpuEye(GpuKernelBase, Op):
__props__ = ('dtype',)
def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
......@@ -915,31 +918,15 @@ class GpuEye(GpuKernelBase, Op):
return [grad_undefined(self, i, inp[i])
for i in xrange(3)]
def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype
def __hash__(self):
return hash(self.dtype) ^ hash(type(self))
def gpu_kernels(self, node, name):
if self.dtype == 'float16':
code = """
KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ga_size nb = n < m ? n : m;
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = __float2half_rn(1);
}
}"""
else:
code = """
code = """
KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ga_size nb = n < m ? n : m;
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = 1;
a[i*m + i] = %(write_a)s(1);
}
}"""
code = code % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
name=name)
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
name=name, write_a=write_w(self.dtype))
return [Kernel(
code=code, name="k",
params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论