提交 618ed85c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make advsub1 work with type context.

上级 79153fde
from __future__ import print_function from __future__ import print_function
import copy import copy
import os
import numpy import numpy
import theano import theano
from theano import tensor, gof, config from theano import tensor, gof
from theano.gof.utils import MethodNotDefined
from six.moves import StringIO from six.moves import StringIO
from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list
import theano.tensor.inplace import theano.tensor.inplace
...@@ -19,7 +17,8 @@ except ImportError: ...@@ -19,7 +17,8 @@ except ImportError:
pass pass
from .type import GpuArrayType from .type import GpuArrayType
from .basic_ops import (as_gpuarray_variable, HideC, GpuKernelBase, Kernel) from .basic_ops import (as_gpuarray_variable, HideC, GpuKernelBase, Kernel,
infer_context_name)
from .elemwise import GpuElemwise from .elemwise import GpuElemwise
...@@ -321,7 +320,7 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor): ...@@ -321,7 +320,7 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor):
%(view_ndim)s, %(view_ndim)s,
dims, dims,
xview_strides, xview_strides,
pygpu_default_context(), %(x)s->ctx,
1, 1,
(PyObject *)%(x)s, (PyObject *)%(x)s,
(PyObject *)&PyGpuArrayType); (PyObject *)&PyGpuArrayType);
...@@ -567,16 +566,16 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1): ...@@ -567,16 +566,16 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
only avail on compute capability 2.0 and more recent. only avail on compute capability 2.0 and more recent.
""" """
_f16_ok = True _f16_ok = True
def make_node(self, x, y, ilist): def make_node(self, x, y, ilist):
"""It defer from GpuAdvancedIncSubtensor1 in that it make sure """It defer from GpuAdvancedIncSubtensor1 in that it make sure
the index are of type long. the index are of type long.
""" """
x_ = as_gpuarray_variable(x) ctx_name = infer_context_name(x, y, ilist)
y_ = as_gpuarray_variable(y) x_ = as_gpuarray_variable(x, ctx_name)
ilist_ = as_gpuarray_variable(ilist) y_ = as_gpuarray_variable(y, ctx_name)
ilist_ = as_gpuarray_variable(ilist, ctx_name)
assert x_.type.dtype == y_.type.dtype assert x_.type.dtype == y_.type.dtype
assert x_.type.ndim >= y_.type.ndim assert x_.type.ndim >= y_.type.ndim
...@@ -599,32 +598,24 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1): ...@@ -599,32 +598,24 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
return gof.Apply(self, [x_, y_, ilist_], [x_.type()]) return gof.Apply(self, [x_, y_, ilist_], [x_.type()])
def get_context(self, node):
return self.node.outputs[0].type.context
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (6,)
def c_headers(self): def c_headers(self):
if pygpu.get_default_context().kind == 'opencl': return ['<numpy_compat.h>', '<gpuarray_helper.h>',
raise MethodNotDefined('cuda only')
return ['cuda.h', '<numpy_compat.h>', '<gpuarray_helper.h>',
'<gpuarray/types.h>'] '<gpuarray/types.h>']
def c_header_dirs(self):
if pygpu.get_default_context().kind == 'opencl':
raise MethodNotDefined('cuda only')
cuda_root = config.cuda.root
res = [os.path.dirname(__file__)]
if cuda_root:
res.append(os.path.join(cuda_root, 'include'))
return res
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
active_device_no = theano.sandbox.cuda.active_device_number() ctx = self.get_context(node)
device_properties = theano.sandbox.cuda.device_properties if ctx.kind != 'cuda':
compute_capability = device_properties(active_device_no)['major'] raise NotImplementedError("cuda only")
if ((self.set_instead_of_inc) or if (self.set_instead_of_inc or
(node.inputs[0].ndim != node.inputs[1].ndim) or node.inputs[0].ndim != node.inputs[1].ndim or
(node.inputs[0].ndim != 2) or node.inputs[0].ndim != 2 or
(compute_capability < 2)): ctx.bin_id[-2] < '2'):
raise NotImplementedError("This case does not have C code yet.") raise NotImplementedError("This case does not have C code yet.")
x = inputs[0] x = inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论