提交 cc7c365c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix Elemwise for float16 inputs/outputs on GPU

上级 d3bba908
...@@ -102,31 +102,55 @@ class GpuElemwise(HideC, Elemwise): ...@@ -102,31 +102,55 @@ class GpuElemwise(HideC, Elemwise):
def generate_kernel(self, node, nodename): def generate_kernel(self, node, nodename):
inps = [make_argument(i, 'i%d' % (n,)) for n, i in inps = [make_argument(i, 'i%d' % (n,)) for n, i in
enumerate(node.inputs)] enumerate(node.inputs)]
scal_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs] scal_v_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
outs = [make_argument(o, 'o%d' % (n,)) for n, o in outs = [make_argument(o, 'o%d' % (n,)) for n, o in
enumerate(node.outputs) if not n in self.inplace_pattern] enumerate(node.outputs) if not n in self.inplace_pattern]
scal_out = [scalar.get_scalar_type(o.dtype) for o in node.outputs] scal_v_outs = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_ins], fake_node = Apply(self.scalar_op, [i() for i in scal_v_ins],
[o() for o in scal_out]) [o() for o in scal_v_outs])
scal_in = [i.name + '[i]' if i.dtype != 'float16' else
'__half2float(' + i.name + '[i])' for i in inps]
scal_out = [] scal_out = []
oi = 0 oi = 0
scal_f16 = []
for n in range(len(node.outputs)): for n in range(len(node.outputs)):
if n in self.inplace_pattern: if n in self.inplace_pattern:
scal_out.append(inps[self.inplace_pattern[n]].name+'[i]') arg = inps[self.inplace_pattern[n]]
else: else:
scal_out.append(outs[oi].name+'[i]') arg = outs[oi]
oi += 1 oi += 1
if arg.dtype == 'float16':
scal_f16.append(('tmpf16%i' % (len(scal_f16),), arg))
scal_out.append(scal_f16[-1][0])
else:
scal_out.append(arg.name + '[i]')
kop = self.scalar_op.c_code(fake_node, nodename+'_scalar', kop = self.scalar_op.c_code(fake_node, nodename+'_scalar',
[i.name+'[i]' for i in inps], scal_in, scal_out,
scal_out,
dict(fail='return;')) dict(fail='return;'))
if scal_f16:
# if we have float16 scalars on output we have to wrap
# them and insert a stand-in float32 variable since
# float16 arithemtic is not available
code = ["{"]
for f in scal_f16:
code.append('ga_float %s;' % (f[0],))
# XXX: The replace is an ugly hack to make sure temp
# variables inthe middle are float32
code.append(kop.replace('npy_uint16', 'ga_float'))
for f in scal_f16:
code.append('%s[i] = __float2half_rn(%s);' % (f[1].name, f[0]))
code.append('}')
kop = '\n'.join(code)
# Translate types for scalar composite ops (except complex). # Translate types for scalar composite ops (except complex).
# NB: OpenCL implicitly has 'stdint' defs at the kernel compilation stage # NB: OpenCL implicitly has 'stdint' defs at the kernel
# compilation stage
support_code = "" if pygpu.get_default_context().kind == 'opencl' else """ support_code = "" if pygpu.get_default_context().kind == 'opencl' else """
#ifdef _MSC_VER #ifdef _MSC_VER
#define signed __int8 int8_t #define signed __int8 int8_t
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论