提交 a17abf5c，作者：Arnaud Bergeron

Use the new clone_float32 in GpuElemwise and remove the remaining hacks.

上级 f35eefb8
......@@ -49,6 +49,11 @@ class GpuElemwise(HideC, Elemwise):
nout = property(lambda self: self.scalar_op.nout)
_f16_ok = True
def __init__(self, scalar_op, *args, **kwargs):
    """Initialise the op, swapping a Composite scalar op for its float32 clone.

    When ``scalar_op`` is a ``Composite`` instance, the object returned by
    its ``clone_float32()`` method is used in its place; any other scalar op
    is passed through unchanged.  The remaining positional and keyword
    arguments are forwarded untouched to ``Elemwise.__init__``.
    """
    # NOTE(review): presumably clone_float32() rebuilds the composite so its
    # inner graph computes in float32 instead of float16 -- confirm against
    # the Composite implementation.
    effective_op = (
        scalar_op.clone_float32()
        if isinstance(scalar_op, Composite)
        else scalar_op
    )
    Elemwise.__init__(self, effective_op, *args, **kwargs)
def __str__(self):
if self.name is not None:
return self.name
......@@ -86,15 +91,7 @@ class GpuElemwise(HideC, Elemwise):
except MethodNotDefined:
pass
if fake_node.op != self.scalar_op:
# If the new op is different due to type changes, we make a new
# op for it.
elem = GpuElemwise(fake_node.op, self.inplace_pattern, self.name,
self.nfunc_spec, self.openmp)
else:
elem = self
node = Apply(elem, inputs, outputs)
node = Apply(self, inputs, outputs)
return node
def get_params(self, node):
......@@ -119,8 +116,7 @@ class GpuElemwise(HideC, Elemwise):
inps, outs,
dict(fail='return;'))
# Some ops like cast will reintroduce float16 in the internal graph.
kop = kop.replace('npy_float16', 'ga_float')
assert 'npy_float16' not in kop
support_code = ""
try:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论