提交 8ff8eb32 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Handle acc_dtype correclty in GpuDnnReduction optimizer.

上级 d1407834
......@@ -3729,27 +3729,26 @@ def local_dnn_reduction(node):
if node.inputs[0].ndim > 8:
return
acc_dtype = node.op._acc_dtype(node.inputs[0].dtype)
if node.inputs[0].dtype != node.outputs[0].dtype:
# We can mix float16 and float32, but not float64.
if (node.inputs[0].dtype == 'float64' or
node.outputs[0].dtype == 'float64'):
return
if node.op.acc_dtype != 'float32':
if acc_dtype != 'float32':
return
if node.inputs[0].dtype not in ['float16', 'float32', 'float64']:
return
if (node.inputs[0].dtype == 'float64' and
node.op.acc_dtype != 'float64'):
if (node.inputs[0].dtype == 'float64' and acc_dtype != 'float64'):
return
if (node.inputs[0].dtype == 'float32' and
node.op.acc_dtype != 'float32'):
if (node.inputs[0].dtype == 'float32' and acc_dtype != 'float32'):
return
if (node.inputs[0].dtype == 'float16' and
node.op.acc_dtype == 'float64'):
if (node.inputs[0].dtype == 'float16' and acc_dtype == 'float64'):
return
def _identity(a):
......@@ -3762,7 +3761,6 @@ def local_dnn_reduction(node):
post = _identity
if node.op.pre_scalar_op is not None:
# Might want to handle absmax, avg, and other cases for (norm1, norm2) here
if isinstance(node.op.scalar_op, theano.scalar.basic.Add):
if isinstance(node.op.pre_scalar_op, theano.scalar.basic.Sqr):
scal = 'norm2'
......@@ -3783,7 +3781,7 @@ def local_dnn_reduction(node):
with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal,
node.op.axis,
node.op._acc_dtype(node.inputs[0].dtype),
acc_dtype,
node.op.dtype,
False)(node.inputs[0])
return [post(ret)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论