提交 7c81e166 authored 作者: notoraptor's avatar notoraptor

Try to fix error related to multiples instances of overloaded function "fabs" with float16.

上级 91bc16c3
...@@ -948,8 +948,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype): ...@@ -948,8 +948,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
assert isinstance(self.scalar_op, (scalar.Maximum, assert isinstance(self.scalar_op, (scalar.Maximum,
scalar.Minimum)) scalar.Minimum))
if self.pre_scalar_op: # TODO: multiple dtypes if self.pre_scalar_op: # TODO: multiple dtypes
# dtype = node.inputs[0].dtype dtype = self._acc_dtype(self.acc_dtype)
dtype = 'float32'
dummy_var = scalar.Scalar(dtype=dtype)() dummy_var = scalar.Scalar(dtype=dtype)()
...@@ -1719,6 +1718,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype): ...@@ -1719,6 +1718,7 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.extend(Scalar(dtype=i.type.dtype).c_code_cache_version()) version.extend(Scalar(dtype=i.type.dtype).c_code_cache_version())
version.extend(self.kernel_version(node)) version.extend(self.kernel_version(node))
version.extend(self._acc_dtype(self.acc_dtype))
if all(version): if all(version):
return tuple(version) return tuple(version)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论