提交 47270716 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Be safer with the support code (in peculiar don't try to bake the C++ complex…

Be safer with the support code (in peculiar don't try to bake the C++ complex struct in the kernel code). Also fix the call syntax for python <= 2.5
上级 75a4738b
...@@ -95,16 +95,23 @@ class GpuElemwise(Op): ...@@ -95,16 +95,23 @@ class GpuElemwise(Op):
sub=dict(fail='return;')) sub=dict(fail='return;'))
res.tag.kcode = kcode res.tag.kcode = kcode
support_code = ""
try: try:
support_code += self.scalar_op.c_support_code_apply(fake_node, 'kcode') code = self.scalar_op.c_support_code_apply(fake_node, 'kcode')
if code:
raise SupportCodeError()
except MethodNotDefined: except MethodNotDefined:
pass pass
support_code = ""
try: try:
support_code += self.scalar_op.c_support_code() support_code += self.scalar_op.c_support_code()
except MethodNotDefined: except MethodNotDefined:
pass pass
if support_code != "#define THEANO_MACRO_MOD(x,y) (x % y)":
# Avoid the C++ complex struct
raise SupportCodeError()
k = ElemwiseKernel(None, inps+outs, kcode, preamble=support_code) k = ElemwiseKernel(None, inps+outs, kcode, preamble=support_code)
res.tag.kernel = k res.tag.kernel = k
...@@ -114,7 +121,13 @@ class GpuElemwise(Op): ...@@ -114,7 +121,13 @@ class GpuElemwise(Op):
k = node.tag.kernel k = node.tag.kernel
outs = [ensure_out(o[0], inps[0]) for o in out] outs = [ensure_out(o[0], inps[0]) for o in out]
k.call_dimspec(*(inps+outs), broadcast=True) # the dict call is there to avoid syntax error in python <= 2.5
k(*(inps+outs), **dict(broadcast=True))
for o, og in zip(out, outs): for o, og in zip(out, outs):
o[0] = og o[0] = og
class SupportCodeError(Exception):
"""
We do not support certain things (such as the C++ complex struct)
"""
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论