提交 604d4d2b authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2307 from nouiz/dnn_pool

CRASH/OPT fix (quick to review)
......@@ -343,17 +343,18 @@ def get_c_extract(r, name, sub):
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]):
# check_broadcast is just an hack to easily remove just the
# broadcast check on the old GPU back-end. THis check isn't
# broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU.
if hasattr(c.op, 'check_broadcast'):
if any([getattr(c.op, 'check_broadcast', True) for (c, _) in
r.clients]):
c_extract = r.type.c_extract(name, sub, True)
else:
try:
c_extract = r.type.c_extract(
name, sub, True,
check_broadcast=c.op.check_broadcast)
check_broadcast=False)
except TypeError, e:
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, False)
......@@ -366,8 +367,18 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage."""
c_extract = r.type.c_extract_out(name, sub,
getattr(r.owner.op, 'check_input', config.check_input))
# check_broadcast is just an hack to easily remove just the
# broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU.
check_input = getattr(r.owner.op, 'check_input', config.check_input)
if getattr(r.owner.op, 'check_broadcast', True):
c_extract = r.type.c_extract_out(name, sub, check_input)
else:
try:
c_extract = r.type.c_extract_out(name, sub, check_input,
check_broadcast=False)
except TypeError, e:
c_extract = r.type.c_extract_out(name, sub, check_input)
pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......
......@@ -88,6 +88,10 @@ class DnnBase(GpuOp):
"""
Creates a handle for cudnn and pulls in the cudnn libraries and headers.
"""
# dnn does not know about broadcasting, so we do not need to assert
# the input broadcasting pattern.
check_broadcast = False
def c_headers(self):
return ['cudnn.h', 'cudnn_helper.h']
......@@ -244,7 +248,6 @@ class GpuDnnConvDesc(GpuOp):
class GpuDnnConvBase(DnnBase):
__props__ = ()
check_broadcast = False
def c_support_code_struct(self, node, struct_id):
return """
......@@ -1274,8 +1277,7 @@ if True:
border_mode=border_mode, subsample=subsample,
direction_hint=direction_hint)]
# DISABLED as there is problems in the handling of borders
# @register_opt('cudnn')
@register_opt('cudnn')
@local_optimizer([GpuDownsampleFactorMax])
def local_pool_dnn(node):
if not dnn_available():
......
......@@ -178,7 +178,7 @@ def test_dnn_tag():
try:
f = theano.function(
[x],
max_pool_2d(x, ds=(2, 2)),
max_pool_2d(x, ds=(2, 2), ignore_border=True),
mode=mode_with_gpu.including("cudnn"))
except (AssertionError, RuntimeError), e:
assert not cuda.dnn.dnn_available()
......
......@@ -360,6 +360,24 @@ class CudaNdarrayType(Type):
#print sio.getvalue()
return sio.getvalue()
def c_extract_out(self, name, sub, check_input=True, check_broadcast=True):
""" To allow the hack to skip check_broadcast.
"""
return """
if (py_%(name)s == Py_None)
{
%(c_init_code)s
}
else
{
%(c_extract_code)s
}
""" % dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input,
check_broadcast))
def c_cleanup(self, name, sub):
return """
//std::cerr << "cleanup " << py_%(name)s << " " << %(name)s << "\\n";
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论