提交 a17abf5c,作者:Arnaud Bergeron

Use the new clone_float32 in GpuElemwise and remove the remaining hacks.

上级 f35eefb8
...@@ -49,6 +49,11 @@ class GpuElemwise(HideC, Elemwise): ...@@ -49,6 +49,11 @@ class GpuElemwise(HideC, Elemwise):
nout = property(lambda self: self.scalar_op.nout) nout = property(lambda self: self.scalar_op.nout)
_f16_ok = True _f16_ok = True
def __init__(self, scalar_op, *args, **kwargs):
    """Initialise the op, swapping a Composite scalar op for its float32 clone.

    When ``scalar_op`` is a ``Composite``, it is replaced by the result of
    ``scalar_op.clone_float32()``; any other scalar op is used as given.
    The chosen op and all remaining arguments are forwarded unchanged to
    ``Elemwise.__init__``.
    """
    # NOTE(review): clone_float32 presumably returns a Composite whose
    # internal graph no longer computes in float16 -- confirm against
    # Composite.clone_float32.
    effective_op = (
        scalar_op.clone_float32()
        if isinstance(scalar_op, Composite)
        else scalar_op
    )
    Elemwise.__init__(self, effective_op, *args, **kwargs)
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -86,15 +91,7 @@ class GpuElemwise(HideC, Elemwise): ...@@ -86,15 +91,7 @@ class GpuElemwise(HideC, Elemwise):
except MethodNotDefined: except MethodNotDefined:
pass pass
if fake_node.op != self.scalar_op: node = Apply(self, inputs, outputs)
# If the new op is different due to type changes, we make a new
# op for it.
elem = GpuElemwise(fake_node.op, self.inplace_pattern, self.name,
self.nfunc_spec, self.openmp)
else:
elem = self
node = Apply(elem, inputs, outputs)
return node return node
def get_params(self, node): def get_params(self, node):
...@@ -119,8 +116,7 @@ class GpuElemwise(HideC, Elemwise): ...@@ -119,8 +116,7 @@ class GpuElemwise(HideC, Elemwise):
inps, outs, inps, outs,
dict(fail='return;')) dict(fail='return;'))
# Some ops like cast will reintroduce float16 in the internal graph. assert 'npy_float16' not in kop
kop = kop.replace('npy_float16', 'ga_float')
support_code = "" support_code = ""
try: try:
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论