提交 dd0d4274 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix elemwise on float16 (hopefully).

上级 df1234f0
......@@ -33,6 +33,12 @@ def as_C_string_const(s):
for l in s.split('\n'))
def get_scal(dt):
if dt == 'float16':
dt = 'float32'
return scalar.get_scalar_type(dt)
class GpuElemwise(HideC, Elemwise):
"""
Elemwise on the GPU.
......@@ -60,23 +66,18 @@ class GpuElemwise(HideC, Elemwise):
zip(out_info[0], out_info[1])]
if len(outputs) > 1:
raise NotImplementedError()
node = Apply(self, inputs, outputs)
# Try to generate the kernel to catch SupportCodeErrors
scal_ins = [get_scal(i.dtype) for i in inputs]
fake_node = self.scalar_op.make_node(*[i() for i in scal_ins])
try:
scal_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
scal_out = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_ins],
[o() for o in scal_out])
code = self.scalar_op.c_support_code_apply(fake_node, "test")
code = fake_node.op.c_support_code_apply(fake_node, "test")
if code:
raise SupportCodeError(code)
except MethodNotDefined:
pass
try:
support_code = self.scalar_op.c_support_code()
support_code = fake_node.op.c_support_code()
if "struct" in support_code:
# The macro is fine, the C++ struct is not.
raise SupportCodeError(
......@@ -85,6 +86,15 @@ class GpuElemwise(HideC, Elemwise):
except MethodNotDefined:
pass
if fake_node.op != self.scalar_op:
# If the new op is different due to type changes, we make a new
# op for it.
elem = GpuElemwise(fake_node.op, self.inplace_pattern, self.name,
self.nfunc_spec, self.openmp)
else:
elem = self
node = Apply(elem, inputs, outputs)
return node
def get_params(self, node):
......@@ -92,59 +102,31 @@ class GpuElemwise(HideC, Elemwise):
def _get_vnames(self, node):
inps = ['i%d' % (n,) for n, _ in enumerate(node.inputs)]
outs = ['o%d' % (n,) for n, _ in enumerate(node.outputs) if n not in self.inplace_pattern]
outs = ['o%d' % (n,) if n not in self.inplace_pattern else
inps[self.inplace_pattern[n]]
for n, _ in enumerate(node.outputs)]
return inps, outs
def _generate_op_string(self, node):
scal_v_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
scal_v_outs = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
inps, outs = self._get_vnames(node)
scal_v_ins = [get_scal(i.dtype)() for i in node.inputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_v_ins],
[o() for o in scal_v_outs])
fake_node = self.scalar_op.make_node(*scal_v_ins)
scal_v_out = fake_node.outputs
assert len(scal_v_out) == len(node.outputs)
scal_in = [i if si.dtype != 'float16' else
'load_half(&' + i + ')' for i, si in zip(inps, scal_v_ins)]
kop = fake_node.op.c_code(fake_node, 'elem_scalar',
inps, outs,
dict(fail='return;'))
scal_out = []
oi = 0
scal_f16 = []
for n in range(len(node.outputs)):
if n in self.inplace_pattern:
arg = inps[self.inplace_pattern[n]]
else:
arg = outs[oi]
oi += 1
if node.outputs[n].dtype == 'float16':
scal_f16.append(('tmpf16%i' % (len(scal_f16),), arg))
scal_out.append(scal_f16[-1][0])
else:
scal_out.append(arg)
kop = self.scalar_op.c_code(fake_node, 'elem_scalar',
scal_in, scal_out,
dict(fail='return;'))
if scal_f16:
# if we have float16 scalars on output we have to wrap
# them and insert a stand-in float32 variable since
# float16 arithemtic is not available
code = ["{"]
for f in scal_f16:
code.append('ga_float %s;' % (f[0],))
# XXX: The replace is an ugly hack to make sure temp
# variables inthe middle are float32
code.append(kop.replace('npy_float16', 'ga_float'))
for f in scal_f16:
code.append('store_half(&%s, %s);' % (f[1], f[0]))
code.append('}')
kop = '\n'.join(code)
# Some ops like cast will reintroduce float16 in the internal graph.
kop = kop.replace('npy_float16', 'ga_float')
support_code = ""
try:
# We accept only some c_support_code().
# This filter is done in the make_node()
support_code += self.scalar_op.c_support_code()
support_code += fake_node.op.c_support_code()
except MethodNotDefined:
pass
for npy, ga in [("npy_uint8", "ga_ubyte"),
......@@ -171,7 +153,7 @@ class GpuElemwise(HideC, Elemwise):
def c_init_code_struct(self, node, name, sub):
inps, outs = self._get_vnames(node)
nargs = len(inps) + len(outs)
nargs = len(inps) + len(outs) - len(self.inplace_pattern)
support_code, kop = self._generate_op_string(node)
res = """
gpuelemwise_arg args[%(nargs)s] = {{0}};
......@@ -202,7 +184,7 @@ class GpuElemwise(HideC, Elemwise):
typecode=o.type.typecode)
res += """
ge = GpuElemwise_new(%(ctx)s->ctx, %(support)s, %(kop)s, %(nargs)s, args, %(nd)s, 0);
ge = GpuElemwise_new(%(ctx)s->ctx, %(support)s, %(kop)s, %(nargs)s, args, %(nd)s, GE_CONVERT_F16);
if (ge == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Could not initialize elemwise support");
%(fail)s
......@@ -363,7 +345,7 @@ class GpuElemwise(HideC, Elemwise):
def c_code_cache_version(self):
ver = self.scalar_op.c_code_cache_version()
if ver:
return (7, ver)
return (8, ver)
else:
return ver
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论