提交 6df1cadc authored 作者: Frederic's avatar Frederic

Add support to gpu reduce for acc and output dtype.

上级 dcc8ea72
......@@ -311,9 +311,10 @@ def local_gpua_careduce(node):
if isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul,
scalar.Maximum, scalar.Minimum)):
x, = node.inputs
greduce = GpuCAReduceCuda(node.op.scalar_op, axis=node.op.axis)
if x.dtype != "float32":
return
greduce = GpuCAReduceCuda(
node.op.scalar_op, axis=node.op.axis,
dtype=getattr(node.op, 'dtype', None),
acc_dtype=getattr(node.op, 'acc_dtype', None))
gvar = greduce(x)
#We need to have the make node called, otherwise the mask can
#be None
......
......@@ -68,9 +68,10 @@ class test_GpuCAReduceCPY(test_CAReduce):
class test_GpuCAReduceCuda(test_GpuCAReduceCPY):
dtypes = ["float32"]
dtypes = ["float32", "int64"]
dtypes = []
bin_dtypes = ["uint8", "int8"]
bin_dtypes = []
cases = [((5, 6), None),
((5, 6), (0, 1)),
((5, 6), (0, )),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论