提交 23bca43d authored 作者: lucasb-eyer's avatar lucasb-eyer 提交者: Frederic Bastien

Cleanup of cudnn pooling zero-batch code.

上级 5921fd10
......@@ -1709,19 +1709,10 @@ if (CudaNdarray_prep_output(&%(out)s, %(nd)s+2, %(out)s_dims) != 0)
}
// if input batch is empty, we return the empty output without calling cuDNN
// (which will fail on zero batch size)
if (CudaNdarray_DIMS(%(input)s)[0] == 0) {
cudaError_t err2 = cudaMemset((%(out)s)->devdata, 0,
CudaNdarray_SIZE(%(out)s) * sizeof(real));
if (err2 != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv could not fill the output with zeros: %%s",
cudaGetErrorString(err2));
%(fail)s
}
// Ideally, "return success" here, but we don't have a %%(done)s
} else {
// (which will fail on zero batch size).
// Ideally, "return success" here, but we don't have a %%(done)s, so just skip the call.
if (CudaNdarray_DIMS(%(input)s)[0] > 0) {
// Don't indent for keeping history
if (c_set_tensorNd(%(input)s, %(input_desc)s) != 0)
%(fail)s
......@@ -1748,7 +1739,7 @@ if (err != CUDNN_STATUS_SUCCESS) {
%(fail)s
}
}
} // Closes the batchdim > 0 check.
""" % dict(out=out, fail=sub['fail'],
name=name, input=inputs[0],
ws=ws, pad=pad, str=stride,
......@@ -1963,19 +1954,10 @@ if (CudaNdarray_prep_output(&%(output_grad)s,
}
// if input batch is empty, we return the empty output without calling cuDNN
// (which will fail on zero batch size)
if (CudaNdarray_DIMS(%(input)s)[0] == 0) {
cudaError_t err2 = cudaMemset((%(output)s)->devdata, 0,
CudaNdarray_SIZE(%(output)s) * sizeof(real));
if (err2 != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuDnnConv could not fill the output with zeros: %%s",
cudaGetErrorString(err2));
%(fail)s
}
// Ideally, "return success" here, but we don't have a %%(done)s, so do else.
} else {
// (which will fail on zero batch size).
// Ideally, "return success" here, but we don't have a %%(done)s, so just skip the call.
if (CudaNdarray_DIMS(%(input)s)[0] > 0) {
// Don't indent for keeping history
if (c_set_tensorNd(%(input)s, %(input_desc)s) != 0)
%(fail)s
......@@ -2031,7 +2013,7 @@ if (err%(name)s != CUDNN_STATUS_SUCCESS) {
%(fail)s
}
}
} // Closes the batchdim > 0 check.
""" % dict(output_grad=out_grad,
fail=sub['fail'], name=name,
input=inp, input_grad=inp_grad, output=out,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论