Commit 644dbb6c authored by Arnaud Bergeron

Rewrite GpuElemwise to be more like the tensor version and much cleaner.

上级 d5bdbb4a
# NOTE(review): the pasted side-by-side diff collapsed the old and new
# import columns onto single lines ("import numpy import numpy", ...);
# this is the reconstructed post-change import block.
from itertools import izip

import numpy

from theano import Op, Apply, scalar
from theano.tensor.elemwise import Elemwise

try:
    import pygpu
    from pygpu.tools import ScalarArg, ArrayArg
    from pygpu.elemwise import ElemwiseKernel
except ImportError:
    # pygpu is optional; ops that need it will fail at use time.
    pass

from theano.sandbox.gpuarray.basic_ops import as_gpuarray_variable
from theano.sandbox.gpuarray.type import GpuArrayType
from theano.gof.utils import MethodNotDefined
...@@ -21,124 +25,140 @@ def make_argument(v, name): ...@@ -21,124 +25,140 @@ def make_argument(v, name):
else: else:
return ArrayArg(numpy.dtype(v.type.dtype), name) return ArrayArg(numpy.dtype(v.type.dtype), name)
def ensure_allocated(storage, shape, dtype):
    """Return output data of the given `shape` and `dtype`, reusing the
    array cached in `storage[0]` when its shape matches.

    `storage` is a one-element list (an output storage cell); on return
    `storage[0]` holds the array that is also returned.
    """
    odat = storage[0]
    if odat is not None:
        if odat.shape != shape:
            # It is unsafe to try to resize odat,
            # we have to allocate output storage.
            odat = None
    if odat is None:
        # NOTE(review): allocates on the GPU; dtype is forwarded as-is.
        odat = pygpu.empty(shape, dtype=dtype)
    storage[0] = odat
    return odat
def as_C_string_const(s):
    """Render the multi-line string `s` as a C string-literal
    concatenation: one double-quoted line per input line, with inner
    double quotes escaped and an explicit \\n appended to each line."""
    return '\n'.join('"%s\\n"' % (l.replace('"', '\\"'))
                     for l in s.split('\n'))
class GpuElemwise(Elemwise):
    """Elemwise operation whose computation runs on a gpuarray.

    Reuses the tensor `Elemwise` machinery (broadcast checks, properties)
    but builds its kernel through pygpu's ``ElemwiseKernel`` and its
    variables with ``GpuArrayType``.
    """
    # nin/nout mirror the wrapped scalar op.
    nin = property(lambda self: self.scalar_op.nin)
    nout = property(lambda self: self.scalar_op.nout)

    def __init__(self, scalar_op, name=None, nfunc_spec=None):
        # We do not support inplace since it is a lie anyway
        # (the scalar_op code will never modify anything inplace)
        Elemwise.__init__(self, scalar_op, inplace_pattern=None, name=name,
                          nfunc_spec=nfunc_spec)

    def __str__(self):
        if self.name is not None:
            return self.name
        return "GpuElemwise{%s}<gpuarray>" % (self.scalar_op,)

    def make_node(self, *inputs):
        # Let the tensor Elemwise compute broadcastable patterns and do
        # the input checks, then rebuild the Apply with gpuarray-typed
        # variables.
        res = Elemwise.make_node(self, *inputs)
        outputs = [GpuArrayType(broadcastable=o.type.broadcastable,
                                dtype=o.type.dtype)() for o in res.outputs]
        inputs = [as_gpuarray_variable(i) for i in inputs]
        res = Apply(self, inputs, outputs)

        # Try to generate the kernel to catch SupportCodeErrors
        k = self.generate_kernel(res, 'test')

        return res

    def generate_kernel(self, node, nodename):
        """Build and return the pygpu ``ElemwiseKernel`` for `node`.

        Raises ``SupportCodeError`` when the scalar op needs support
        code that cannot be forwarded to the GPU kernel.
        """
        inps = [make_argument(i, 'i%d' % (n,)) for n, i in
                enumerate(node.inputs)]
        scal_ins = [scalar.Scalar(i.dtype) for i in node.inputs]

        outs = [make_argument(o, 'o%d' % (n,)) for n, o in
                enumerate(node.outputs)]
        scal_out = [scalar.Scalar(o.dtype) for o in node.outputs]

        # A fake scalar Apply used only to render the scalar C code.
        fake_node = Apply(self.scalar_op, [i() for i in scal_ins],
                          [o() for o in scal_out])

        try:
            code = self.scalar_op.c_support_code_apply(fake_node, nodename)
            if code:
                # Per-apply support code cannot be used in the kernel.
                raise SupportCodeError(code)
        except MethodNotDefined:
            pass

        support_code = ""
        try:
            support_code = self.scalar_op.c_support_code()
        except MethodNotDefined:
            pass

        if (support_code != "#define THEANO_MACRO_MOD(x,y) (x % y)" and
            support_code != ""):
            # The macro is fine, the C++ struct is not.
            raise SupportCodeError(support_code)

        kop = self.scalar_op.c_code(fake_node, nodename+'_scalar',
                                    [i.name+'[i]' for i in inps],
                                    [o.name+'[i]' for o in outs],
                                    dict(fail='return;'))

        # Translate types for scalar composite ops (except complex).
        support_code += """
#define npy_float64 ga_double
#define npy_float32 ga_float
#define npy_uint8 ga_ubyte
#define npy_int8 ga_byte
#define npy_uint16 ga_ushort
#define npy_int16 ga_short
#define npy_uint32 ga_uint
#define npy_int32 ga_int
#define npy_uint64 ga_ulong
#define npy_int64 ga_long
"""

        return ElemwiseKernel(None, inps+outs, kop, preamble=support_code)

    def c_support_code_apply(self, node, nodename):
        # This is useless by itself, but will serve an eventual c_code
        # implementation.  Emits the rendered kernel sources as C string
        # constants plus NULL kernel handles for each dimensionality.
        k = self.generate_kernel(node, nodename)
        nd = node.inputs[0].type.ndim
        res = []
        for i in range(1, nd):
            var = "static const char %s_%s[] = " % (nodename, str(i))
            res.append(var + as_C_string_const(k.render_basic(i)) + ';')
            res.append("static const gpukernel *%s_%s_k = NULL;" % (nodename,
                                                                    str(i)))
        var = "static const char %s_c[] = " % (nodename,)
        res.append(var + as_C_string_const(k.contig_src) + ';')
        res.append("static const gpukernel *%s_c_k = NULL;" % (nodename,))
        return '\n'.join(res)

    def c_code(self, *args):
        # do not pick up the Elemwise version
        raise MethodNotDefined('c_code')

    def perform(self, node, inputs, output_storage):
        # Try to reuse the kernel from a previous call to hopefully
        # avoid recompiling
        if not hasattr(node, '_cache_elemwise_k'):
            node._cache_elemwise_k = self.generate_kernel(node, "kcode")

        # Compute the broadcasted output shape from the input shapes.
        out_shape = []
        for values in izip(*[input.shape for input in inputs]):
            if any(v == 0 for v in values):
                # All non-broadcasted dimensions should be zero
                assert max(values) <= 1
                out_shape.append(0)
            else:
                out_shape.append(max(values))
        out_shape = tuple(out_shape)

        # ensure_allocated writes back into the storage cells, so no
        # explicit copy into output_storage is needed afterwards.
        outs = [ensure_allocated(storage, out_shape, output.type.dtype)
                for output, storage in izip(node.outputs, output_storage)]

        # the dict call is there to avoid a syntax error in python < 2.6
        node._cache_elemwise_k(*(inputs+outs), **dict(broadcast=True))
class SupportCodeError(Exception): class SupportCodeError(Exception):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论