提交 d43f6df9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Disable C code for all Elemwise that interact with float16

上级 9111dd3d
...@@ -1982,8 +1982,6 @@ class Cast(UnaryScalarOp): ...@@ -1982,8 +1982,6 @@ class Cast(UnaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
if node.inputs[0].dtype == 'float16' or node.outputs[0] == 'float16':
raise NotImplementedError("C code doesn't work for float16")
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x) return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, inputs, gout): def grad(self, inputs, gout):
...@@ -3302,14 +3300,19 @@ class Composite(ScalarOp): ...@@ -3302,14 +3300,19 @@ class Composite(ScalarOp):
+ zip(self.fgraph.outputs, + zip(self.fgraph.outputs,
["%%(o%i)s" % i for i in xrange(len(self.fgraph.outputs))])) ["%%(o%i)s" % i for i in xrange(len(self.fgraph.outputs))]))
for orphan in self.fgraph.variables: # fgraph.orphans: for var in self.fgraph.variables:
if orphan.owner is None and orphan not in self.fgraph.inputs: if var.owner is None and var not in self.fgraph.inputs:
if isinstance(orphan, Constant): # This is an orphan
subd[orphan] = orphan.type.c_literal(orphan.data) if isinstance(var, Constant):
subd[var] = var.type.c_literal(var.data)
else: else:
raise ValueError( raise ValueError(
"All orphans in the fgraph to Composite must" "All orphans in the fgraph to Composite must"
" be Constant instances.") " be Constant instances.")
elif any(i.dtype == 'float16' for i in var.owner.inputs or
o.dtype == 'float16' for o in var.owner.outputs):
# flag for elemwise ops to check.
self.inner_float16 = True
_c_code = "{\n" _c_code = "{\n"
self.nodenames = ["%(nodename)s_" + ('subnode%i' % j) self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
......
...@@ -1171,6 +1171,12 @@ class Elemwise(OpenMPOp): ...@@ -1171,6 +1171,12 @@ class Elemwise(OpenMPOp):
return decl, checks, alloc, loop return decl, checks, alloc, loop
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
if any(i.dtype == 'float16' for i in node.inputs or
o.dtype == 'float16' for o in node.inputs or
# This is for Composite
getattr(self.scalar_op, 'inner_float16', False)):
# Disable C code for float16 vars
super(Elemwise, self).c_code(node, nodename, inames, onames, sub)
code = "\n".join(self._c_all(node, nodename, inames, onames, sub)) code = "\n".join(self._c_all(node, nodename, inames, onames, sub))
return code return code
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论