Unverified commit 61b514a0, authored by Frédéric Bastien, committed by GitHub

Merge pull request #6571 from abergeron/fix_broken_reduce

Fix non-cudnn reductions that were all subtly broken by 91bc16c3.
...@@ -3729,27 +3729,26 @@ def local_dnn_reduction(node): ...@@ -3729,27 +3729,26 @@ def local_dnn_reduction(node):
if node.inputs[0].ndim > 8: if node.inputs[0].ndim > 8:
return return
acc_dtype = node.op._acc_dtype(node.inputs[0].dtype)
if node.inputs[0].dtype != node.outputs[0].dtype: if node.inputs[0].dtype != node.outputs[0].dtype:
# We can mix float16 and float32, but not float64. # We can mix float16 and float32, but not float64.
if (node.inputs[0].dtype == 'float64' or if (node.inputs[0].dtype == 'float64' or
node.outputs[0].dtype == 'float64'): node.outputs[0].dtype == 'float64'):
return return
if node.op.acc_dtype != 'float32': if acc_dtype != 'float32':
return return
if node.inputs[0].dtype not in ['float16', 'float32', 'float64']: if node.inputs[0].dtype not in ['float16', 'float32', 'float64']:
return return
if (node.inputs[0].dtype == 'float64' and if (node.inputs[0].dtype == 'float64' and acc_dtype != 'float64'):
node.op.acc_dtype != 'float64'):
return return
if (node.inputs[0].dtype == 'float32' and if (node.inputs[0].dtype == 'float32' and acc_dtype != 'float32'):
node.op.acc_dtype != 'float32'):
return return
if (node.inputs[0].dtype == 'float16' and if (node.inputs[0].dtype == 'float16' and acc_dtype == 'float64'):
node.op.acc_dtype == 'float64'):
return return
def _identity(a): def _identity(a):
...@@ -3762,7 +3761,6 @@ def local_dnn_reduction(node): ...@@ -3762,7 +3761,6 @@ def local_dnn_reduction(node):
post = _identity post = _identity
if node.op.pre_scalar_op is not None: if node.op.pre_scalar_op is not None:
# Might want to handle absmax, avg, and other cases for (norm1, norm2) here
if isinstance(node.op.scalar_op, theano.scalar.basic.Add): if isinstance(node.op.scalar_op, theano.scalar.basic.Add):
if isinstance(node.op.pre_scalar_op, theano.scalar.basic.Sqr): if isinstance(node.op.pre_scalar_op, theano.scalar.basic.Sqr):
scal = 'norm2' scal = 'norm2'
...@@ -3783,7 +3781,7 @@ def local_dnn_reduction(node): ...@@ -3783,7 +3781,7 @@ def local_dnn_reduction(node):
with inherit_stack_trace(node.outputs): with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal, ret = GpuDnnReduction(scal,
node.op.axis, node.op.axis,
node.op.acc_dtype, acc_dtype,
node.op.dtype, node.op.dtype,
False)(node.inputs[0]) False)(node.inputs[0])
return [post(ret)] return [post(ret)]
......
...@@ -1207,7 +1207,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs): ...@@ -1207,7 +1207,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
return False return False
x, = inputs x, = inputs
idtype = x.dtype idtype = x.dtype
adtype = getattr(op, 'acc_dtype', idtype) adtype = getattr(op, 'acc_dtype', None)
odtype = getattr(op, 'dtype', outputs[0].dtype) odtype = getattr(op, 'dtype', outputs[0].dtype)
# Force accumulator to float32 for float32 inputs since tree # Force accumulator to float32 for float32 inputs since tree
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment