提交 d43f6df9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Disable C code for all Elemwise that interact with float16

上级 9111dd3d
......@@ -1982,8 +1982,6 @@ class Cast(UnaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
if node.inputs[0].dtype == 'float16' or node.outputs[0] == 'float16':
raise NotImplementedError("C code doesn't work for float16")
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, inputs, gout):
......@@ -3302,14 +3300,19 @@ class Composite(ScalarOp):
+ zip(self.fgraph.outputs,
["%%(o%i)s" % i for i in xrange(len(self.fgraph.outputs))]))
for orphan in self.fgraph.variables: # fgraph.orphans:
if orphan.owner is None and orphan not in self.fgraph.inputs:
if isinstance(orphan, Constant):
subd[orphan] = orphan.type.c_literal(orphan.data)
for var in self.fgraph.variables:
if var.owner is None and var not in self.fgraph.inputs:
# This is an orphan
if isinstance(var, Constant):
subd[var] = var.type.c_literal(var.data)
else:
raise ValueError(
"All orphans in the fgraph to Composite must"
" be Constant instances.")
elif any(i.dtype == 'float16' for i in var.owner.inputs or
o.dtype == 'float16' for o in var.owner.outputs):
# flag for elemwise ops to check.
self.inner_float16 = True
_c_code = "{\n"
self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
......
......@@ -1171,6 +1171,12 @@ class Elemwise(OpenMPOp):
return decl, checks, alloc, loop
def c_code(self, node, nodename, inames, onames, sub):
if any(i.dtype == 'float16' for i in node.inputs or
o.dtype == 'float16' for o in node.inputs or
# This is for Composite
getattr(self.scalar_op, 'inner_float16', False)):
# Disable C code for float16 vars
super(Elemwise, self).c_code(node, nodename, inames, onames, sub)
code = "\n".join(self._c_all(node, nodename, inames, onames, sub))
return code
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论