提交 e422c2af authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Fix optimization to check if cudnn available

上级 fc4ebe25
......@@ -1596,7 +1596,11 @@ def local_gpua_lift_abstractconv_graph(op, context_name, inputs, outputs):
@op_lifter([pool.Pool])
@register_opt2([pool.Pool])
def local_gpu_pool(op, ctx_name, inputs, outputs):
from .dnn import dnn_available
assert op.__props__ == ('ignore_border', 'mode', 'ndim')
if op.ignore_border and dnn_available(ctx_name):
return
inp, ws, stride, pad = inputs
nd = op.ndim
if nd not in (2, 3):
......@@ -1617,7 +1621,11 @@ def local_gpu_pool(op, ctx_name, inputs, outputs):
@op_lifter([pool.MaxPoolGrad])
@register_opt2([pool.MaxPoolGrad])
def local_gpu_max_pool_grad(op, ctx_name, inputs, outputs):
from .dnn import dnn_available
assert op.__props__ == ('ignore_border', 'mode', 'ndim')
if op.ignore_border and dnn_available(ctx_name):
return
inp, out, out_grad, ws, stride, pad = inputs
nd = op.ndim
if nd not in (2, 3):
......@@ -1643,7 +1651,11 @@ def local_gpu_max_pool_grad(op, ctx_name, inputs, outputs):
@op_lifter([pool.AveragePoolGrad])
@register_opt2([pool.AveragePoolGrad])
def local_gpu_average_pool_grad(op, ctx_name, inputs, outputs):
from .dnn import dnn_available
assert op.__props__ == ('ignore_border', 'mode', 'ndim')
if op.ignore_border and dnn_available(ctx_name):
return
inp, out_grad, ws, stride, pad = inputs
nd = op.ndim
if nd not in (2, 3):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论