提交 d4f7d749 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some small comments from review.

上级 b33b4462
......@@ -72,7 +72,10 @@ The last place where you might need the context is in the C
initialization code. For that you will have to use the :ref:`params
<extending_op_params>`. The params type should be
:class:`theano.gpuarray.type.gpu_context_type` and the params object
should be a context object from one of your input variables.
should be a context object from one of your input variables::
def get_params(self, node):
return node.inputs[0].type.context
If you don't have any input variables on the GPU you can follow the
the example of :class:`theano.gpuarray.basic_ops.GpuFromHost` or
......
......@@ -162,10 +162,13 @@ class Kernel(object):
dictionary of flags
codevar: str
the name of the variable for the code object.
(defaults to 'kcode_' + name)
binvar: str
the name of the variable for the binary object.
(defaults to 'kbin_' + name)
objvar: str
the name of the variable for the kernel object.
(defaults to 'k_' + name)
"""
......@@ -362,7 +365,7 @@ class GpuKernelBase(object):
return (4, self.get_params(node).bin_id)
def fwds(name):
def forward_string_meth(name):
def f(*args):
res = getattr(GpuKernelBase, name)(*args)
try:
......@@ -386,11 +389,11 @@ class CGpuKernelBase(COp, GpuKernelBase):
kernel_re = re.compile(r'^#kernel ([a-zA-Z_].*?)$', re.MULTILINE)
c_support_code = fwds('c_support_code')
c_support_code_apply = fwds('c_support_code_apply')
c_support_code_struct = fwds('c_support_code_struct')
c_init_code_struct = fwds('c_init_code_struct')
c_cleanup_code_struct = fwds('c_cleanup_code_struct')
c_support_code = forward_string_meth('c_support_code')
c_support_code_apply = forward_string_meth('c_support_code_apply')
c_support_code_struct = forward_string_meth('c_support_code_struct')
c_init_code_struct = forward_string_meth('c_init_code_struct')
c_cleanup_code_struct = forward_string_meth('c_cleanup_code_struct')
def _type_macros(self, node):
define_template = "#define %s %s"
......@@ -428,7 +431,7 @@ class CGpuKernelBase(COp, GpuKernelBase):
res = []
while n < len(split):
kspec = split[n]
kcode = split[n+1]
kcode = split[n + 1]
splt2 = kspec.split(':')
if len(splt2) != 3:
raise ValueError("Bad kernel spec: %s" % (kspec,))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论