提交 55a1fe82 authored 作者: Frederic's avatar Frederic

hack to lower the number of compilation on the GPU.

上级 01187ffe
......@@ -342,8 +342,18 @@ def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]):
c_extract = r.type.c_extract(name, sub, True)
# check_broadcast is just a hack to easily remove just the
# broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU.
if hasattr(c.op, 'check_broadcast'):
try:
c_extract = r.type.c_extract(
name, sub, True,
check_broadcast=c.op.check_broadcast)
except TypeError, e:
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, False)
......
......@@ -244,6 +244,7 @@ class GpuDnnConvDesc(GpuOp):
class GpuDnnConvBase(DnnBase):
__props__ = ()
check_broadcast = False
def c_support_code_struct(self, node, struct_id):
return """
......
......@@ -280,7 +280,8 @@ class CudaNdarrayType(Type):
def c_init(self, name, sub):
return "%(name)s = NULL;" % locals()
def c_extract(self, name, sub, check_input=True):
def c_extract(self, name, sub, check_input=True,
check_broadcast=True):
sio = StringIO()
fail = sub['fail']
nd = self.ndim
......@@ -307,7 +308,7 @@ class CudaNdarrayType(Type):
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals()
for i, b in enumerate(self.broadcastable):
if b:
if b and check_broadcast:
print >> sio, """
if (CudaNdarray_HOST_DIMS(%(name)s)[%(i)s] != 1)
{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论