提交 092835c1 authored 作者: Frederic's avatar Frederic

Better hack to skip broadcast check on the GPU.

上级 4c326058
...@@ -343,17 +343,18 @@ def get_c_extract(r, name, sub): ...@@ -343,17 +343,18 @@ def get_c_extract(r, name, sub):
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]): r.clients]):
# check_broadcast is just a hack to easily remove just the # check_broadcast is just a hack to easily remove just the
# broadcast check on the old GPU back-end. THis check isn't # broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU. # done in the new GPU back-end or on the CPU.
if hasattr(r.clients[-1][0].op, 'check_broadcast'): if any([getattr(c.op, 'check_broadcast', True) for (c, _) in
r.clients]):
c_extract = r.type.c_extract(name, sub, True)
else:
try: try:
c_extract = r.type.c_extract( c_extract = r.type.c_extract(
name, sub, True, name, sub, True,
check_broadcast=c.op.check_broadcast) check_broadcast=False)
except TypeError, e: except TypeError, e:
c_extract = r.type.c_extract(name, sub, True) c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, True)
else: else:
c_extract = r.type.c_extract(name, sub, False) c_extract = r.type.c_extract(name, sub, False)
...@@ -366,8 +367,19 @@ def get_c_extract(r, name, sub): ...@@ -366,8 +367,19 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub): def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage.""" """Wrapper around c_extract_out that initializes py_name from storage."""
c_extract = r.type.c_extract_out(name, sub, # check_broadcast is just a hack to easily remove just the
getattr(r.owner.op, 'check_input', config.check_input)) # broadcast check on the old GPU back-end. This check isn't
# done in the new GPU back-end or on the CPU.
check_input = getattr(r.owner.op, 'check_input', config.check_input)
if getattr(r.owner.op, 'check_broadcast', True):
c_extract = r.type.c_extract_out(name, sub, check_input)
else:
try:
c_extract = r.type.c_extract_out(name, sub, check_input,
check_broadcast=False)
except TypeError, e:
import pdb;pdb.set_trace()
c_extract = r.type.c_extract_out(name, sub, check_input)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......
...@@ -360,6 +360,24 @@ class CudaNdarrayType(Type): ...@@ -360,6 +360,24 @@ class CudaNdarrayType(Type):
#print sio.getvalue() #print sio.getvalue()
return sio.getvalue() return sio.getvalue()
def c_extract_out(self, name, sub, check_input=True, check_broadcast=True):
"""Return the C code that initializes or extracts an output variable.

The generated code runs this type's ``c_init`` code when
``py_<name>`` is ``Py_None`` and this type's ``c_extract`` code
otherwise.

This method exists so that ``check_broadcast`` can be forwarded to
``c_extract``; that is the hack that allows the broadcast check to be
skipped.

:param name: substituted for ``%(name)s`` in the generated code and
passed to ``c_init`` and ``c_extract``.
:param sub: passed through unchanged to ``c_init`` and ``c_extract``.
:param check_input: forwarded to ``c_extract`` as its third argument.
:param check_broadcast: forwarded to ``c_extract`` as its fourth
argument.
:return: the generated C code as a string.
"""
return """
if (py_%(name)s == Py_None)
{
%(c_init_code)s
}
else
{
%(c_extract_code)s
}
""" % dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input,
check_broadcast))
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return """ return """
//std::cerr << "cleanup " << py_%(name)s << " " << %(name)s << "\\n"; //std::cerr << "cleanup " << py_%(name)s << " " << %(name)s << "\\n";
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论