提交 f1b31d35 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix local_gpua_careduce to return the correct dtype for the Max reduction.

上级 b416e3b6
...@@ -1017,10 +1017,9 @@ def local_gpua_careduce(op, context_name, inputs, outputs): ...@@ -1017,10 +1017,9 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
else: else:
return False return False
x, = inputs x, = inputs
greduce = op2( greduce = op2(
op.scalar_op, axis=op.axis, op.scalar_op, axis=op.axis,
dtype=getattr(op, 'dtype', None), dtype=getattr(op, 'dtype', outputs[0].dtype),
acc_dtype=getattr(op, 'acc_dtype', None)) acc_dtype=getattr(op, 'acc_dtype', None))
gvar = greduce(x) gvar = greduce(x)
# We need to have the make node called, otherwise the mask can # We need to have the make node called, otherwise the mask can
...@@ -1059,7 +1058,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs): ...@@ -1059,7 +1058,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
greduce = op2( greduce = op2(
op.scalar_op, op.scalar_op,
axis=new_axis, reduce_mask=new_mask, axis=new_axis, reduce_mask=new_mask,
dtype=getattr(op, 'dtype', None), dtype=getattr(op, 'dtype', outputs[0].dtype),
acc_dtype=getattr(op, 'acc_dtype', None)) acc_dtype=getattr(op, 'acc_dtype', None))
reshaped_x = x.reshape(tensor.stack(new_in_shp)) reshaped_x = x.reshape(tensor.stack(new_in_shp))
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论