提交 55a1fe82 authored 作者: Frederic's avatar Frederic

hack to lower the number of compilation on the GPU.

上级 01187ffe
...@@ -342,7 +342,17 @@ def get_c_extract(r, name, sub): ...@@ -342,7 +342,17 @@ def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage.""" """Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]): r.clients]):
# check_broadcast is just an hack to easily remove just the
# broadcast check on the old GPU back-end. THis check isn't
# done in the new GPU back-end or on the CPU.
if hasattr(c.op, 'check_broadcast'):
try:
c_extract = r.type.c_extract(
name, sub, True,
check_broadcast=c.op.check_broadcast)
except TypeError, e:
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, True) c_extract = r.type.c_extract(name, sub, True)
else: else:
c_extract = r.type.c_extract(name, sub, False) c_extract = r.type.c_extract(name, sub, False)
......
...@@ -244,6 +244,7 @@ class GpuDnnConvDesc(GpuOp): ...@@ -244,6 +244,7 @@ class GpuDnnConvDesc(GpuOp):
class GpuDnnConvBase(DnnBase): class GpuDnnConvBase(DnnBase):
__props__ = () __props__ = ()
check_broadcast = False
def c_support_code_struct(self, node, struct_id): def c_support_code_struct(self, node, struct_id):
return """ return """
......
...@@ -280,7 +280,8 @@ class CudaNdarrayType(Type): ...@@ -280,7 +280,8 @@ class CudaNdarrayType(Type):
def c_init(self, name, sub): def c_init(self, name, sub):
return "%(name)s = NULL;" % locals() return "%(name)s = NULL;" % locals()
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True,
check_broadcast=True):
sio = StringIO() sio = StringIO()
fail = sub['fail'] fail = sub['fail']
nd = self.ndim nd = self.ndim
...@@ -307,7 +308,7 @@ class CudaNdarrayType(Type): ...@@ -307,7 +308,7 @@ class CudaNdarrayType(Type):
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n"; //std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals() """ % locals()
for i, b in enumerate(self.broadcastable): for i, b in enumerate(self.broadcastable):
if b: if b and check_broadcast:
print >> sio, """ print >> sio, """
if (CudaNdarray_HOST_DIMS(%(name)s)[%(i)s] != 1) if (CudaNdarray_HOST_DIMS(%(name)s)[%(i)s] != 1)
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论