提交 7c81e166 authored 作者: notoraptor's avatar notoraptor

Try to fix error related to multiples instances of overloaded function "fabs" with float16.

上级 91bc16c3
......@@ -948,8 +948,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
assert isinstance(self.scalar_op, (scalar.Maximum,
scalar.Minimum))
if self.pre_scalar_op: # TODO: multiple dtypes
# dtype = node.inputs[0].dtype
dtype = 'float32'
dtype = self._acc_dtype(self.acc_dtype)
dummy_var = scalar.Scalar(dtype=dtype)()
......@@ -1719,6 +1718,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
for i in node.inputs + node.outputs:
version.extend(Scalar(dtype=i.type.dtype).c_code_cache_version())
version.extend(self.kernel_version(node))
version.extend(self._acc_dtype(self.acc_dtype))
if all(version):
return tuple(version)
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论