Unverified 提交 61b514a0 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6571 from abergeron/fix_broken_reduce

Fix non-cudnn reductions that were all subtly broken by 91bc16c3.
......@@ -3729,27 +3729,26 @@ def local_dnn_reduction(node):
if node.inputs[0].ndim > 8:
return
acc_dtype = node.op._acc_dtype(node.inputs[0].dtype)
if node.inputs[0].dtype != node.outputs[0].dtype:
# We can mix float16 and float32, but not float64.
if (node.inputs[0].dtype == 'float64' or
node.outputs[0].dtype == 'float64'):
return
if node.op.acc_dtype != 'float32':
if acc_dtype != 'float32':
return
if node.inputs[0].dtype not in ['float16', 'float32', 'float64']:
return
if (node.inputs[0].dtype == 'float64' and
node.op.acc_dtype != 'float64'):
if (node.inputs[0].dtype == 'float64' and acc_dtype != 'float64'):
return
if (node.inputs[0].dtype == 'float32' and
node.op.acc_dtype != 'float32'):
if (node.inputs[0].dtype == 'float32' and acc_dtype != 'float32'):
return
if (node.inputs[0].dtype == 'float16' and
node.op.acc_dtype == 'float64'):
if (node.inputs[0].dtype == 'float16' and acc_dtype == 'float64'):
return
def _identity(a):
......@@ -3762,7 +3761,6 @@ def local_dnn_reduction(node):
post = _identity
if node.op.pre_scalar_op is not None:
# Might want to handle absmax, avg, and other cases for (norm1, norm2) here
if isinstance(node.op.scalar_op, theano.scalar.basic.Add):
if isinstance(node.op.pre_scalar_op, theano.scalar.basic.Sqr):
scal = 'norm2'
......@@ -3783,7 +3781,7 @@ def local_dnn_reduction(node):
with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal,
node.op.axis,
node.op.acc_dtype,
acc_dtype,
node.op.dtype,
False)(node.inputs[0])
return [post(ret)]
......
......@@ -1207,7 +1207,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
return False
x, = inputs
idtype = x.dtype
adtype = getattr(op, 'acc_dtype', idtype)
adtype = getattr(op, 'acc_dtype', None)
odtype = getattr(op, 'dtype', outputs[0].dtype)
# Force accumulator to float32 for float32 inputs since tree
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论