提交 073b822d authored 作者: Thomas George's avatar Thomas George

GpuAlloc now uses params

上级 d003afce
...@@ -8,6 +8,7 @@ import theano ...@@ -8,6 +8,7 @@ import theano
from theano import Op, Apply, Type, Variable from theano import Op, Apply, Type, Variable
from theano import tensor, config from theano import tensor, config
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.scalar import bool as bool_t
from theano.tensor.basic import ( from theano.tensor.basic import (
Alloc, AllocEmpty, alloc_validate_shape, Join, Split) Alloc, AllocEmpty, alloc_validate_shape, Join, Split)
...@@ -808,14 +809,15 @@ class GpuAlloc(HideC, Alloc): ...@@ -808,14 +809,15 @@ class GpuAlloc(HideC, Alloc):
__props__ = ('memset_0', 'context_name') __props__ = ('memset_0', 'context_name')
_f16_ok = True _f16_ok = True
params_type = gpu_context_type params_type = ParamsType(context=gpu_context_type, memset_0=bool_t)
def __init__(self, context_name, memset_0=False): def __init__(self, context_name, memset_0=False):
self.context_name = context_name self.context_name = context_name
self.memset_0 = memset_0 self.memset_0 = memset_0
def get_params(self, node): def get_params(self, node):
return get_context(self.context_name) return self.params_type.get_params(context=get_context(self.context_name),
memset_0=self.memset_0)
def __str__(self): def __str__(self):
# Hide the memset parameter when not used to prevent confusion. # Hide the memset parameter when not used to prevent confusion.
...@@ -837,15 +839,15 @@ class GpuAlloc(HideC, Alloc): ...@@ -837,15 +839,15 @@ class GpuAlloc(HideC, Alloc):
def c_headers(self): def c_headers(self):
return ['<numpy_compat.h>'] return ['<numpy_compat.h>']
def perform(self, node, inputs, outs, ctx): def perform(self, node, inputs, outs, params):
out, = outs out, = outs
v = inputs[0] v = inputs[0]
sh = tuple(map(int, inputs[1:])) sh = tuple(map(int, inputs[1:]))
if out[0] is None or out[0].shape != sh: if out[0] is None or out[0].shape != sh:
if self.memset_0: if self.memset_0:
out[0] = gpuarray.zeros(sh, dtype=v.dtype, context=ctx) out[0] = gpuarray.zeros(sh, dtype=v.dtype, context=params.context)
else: else:
out[0] = gpuarray.empty(sh, dtype=v.dtype, context=ctx) out[0] = gpuarray.empty(sh, dtype=v.dtype, context=params.context)
out[0][...] = v out[0][...] = v
else: else:
out[0][...] = v out[0][...] = v
...@@ -857,7 +859,6 @@ class GpuAlloc(HideC, Alloc): ...@@ -857,7 +859,6 @@ class GpuAlloc(HideC, Alloc):
ndim = len(inp[1:]) ndim = len(inp[1:])
zz, = out zz, = out
memset_0 = int(self.memset_0)
code = """ code = """
int i; int i;
size_t %(name)s_shape[%(ndim)s]; size_t %(name)s_shape[%(ndim)s];
...@@ -875,12 +876,12 @@ class GpuAlloc(HideC, Alloc): ...@@ -875,12 +876,12 @@ class GpuAlloc(HideC, Alloc):
for (i = 0; i < %(ndim)s; i++) for (i = 0; i < %(ndim)s; i++)
need_new_out |= %(zz)s->ga.dimensions[i] != %(name)s_shape[i]; need_new_out |= %(zz)s->ga.dimensions[i] != %(name)s_shape[i];
if (need_new_out && (%(memset_0)s)) { if (need_new_out && (%(params)s->memset_0)) {
//pygpu_zeros can be faster then empty followed by memset. //pygpu_zeros can be faster then empty followed by memset.
Py_XDECREF(%(zz)s); Py_XDECREF(%(zz)s);
%(zz)s = pygpu_zeros(%(ndim)s, %(name)s_shape, %(zz)s = pygpu_zeros(%(ndim)s, %(name)s_shape,
%(vv)s->ga.typecode, GA_C_ORDER, %(vv)s->ga.typecode, GA_C_ORDER,
%(ctx)s, Py_None); %(params)s->context, Py_None);
if (!%(zz)s) { if (!%(zz)s) {
%(fail)s %(fail)s
} }
...@@ -889,12 +890,12 @@ class GpuAlloc(HideC, Alloc): ...@@ -889,12 +890,12 @@ class GpuAlloc(HideC, Alloc):
Py_XDECREF(%(zz)s); Py_XDECREF(%(zz)s);
%(zz)s = pygpu_empty(%(ndim)s, %(name)s_shape, %(zz)s = pygpu_empty(%(ndim)s, %(name)s_shape,
%(vv)s->ga.typecode, GA_C_ORDER, %(vv)s->ga.typecode, GA_C_ORDER,
%(ctx)s, Py_None); %(params)s->context, Py_None);
if (!%(zz)s) { if (!%(zz)s) {
%(fail)s %(fail)s
} }
} }
if (%(memset_0)s && GpuArray_ISONESEGMENT(&%(zz)s->ga)) if (%(params)s->memset_0 && GpuArray_ISONESEGMENT(&%(zz)s->ga))
{ {
int err = GpuArray_memset(&%(zz)s->ga, 0); int err = GpuArray_memset(&%(zz)s->ga, 0);
if (err != GA_NO_ERROR) if (err != GA_NO_ERROR)
...@@ -912,8 +913,8 @@ class GpuAlloc(HideC, Alloc): ...@@ -912,8 +913,8 @@ class GpuAlloc(HideC, Alloc):
%(fail)s %(fail)s
} }
} }
""" % dict(name=name, ndim=ndim, zz=zz, vv=vv, ctx=sub['params'], """ % dict(name=name, ndim=ndim, zz=zz, vv=vv, params=sub['params'],
fail=sub['fail'], memset_0=memset_0) fail=sub['fail'])
if config.gpuarray.sync: if config.gpuarray.sync:
code += "GpuArray_sync(&%(zz)s->ga);" % dict(zz=zz) code += "GpuArray_sync(&%(zz)s->ga);" % dict(zz=zz)
...@@ -921,7 +922,7 @@ class GpuAlloc(HideC, Alloc): ...@@ -921,7 +922,7 @@ class GpuAlloc(HideC, Alloc):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def do_constant_folding(self, node): def do_constant_folding(self, node):
from . import subtensor, blas from . import subtensor, blas
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论