提交 87e06fdb authored 作者: Thomas George's avatar Thomas George

GpuGer now uses params instead of a class attribute for inplace

上级 83debd73
...@@ -6,7 +6,8 @@ import theano ...@@ -6,7 +6,8 @@ import theano
from theano import Apply, config, Op from theano import Apply, config, Op
from theano.compile import optdb from theano.compile import optdb
from theano.gof import LocalOptGroup from theano.gof import LocalOptGroup, ParamsType
from theano.scalar import Scalar
from theano.tensor.basic import as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
...@@ -257,6 +258,7 @@ class GpuGer(BlasOp): ...@@ -257,6 +258,7 @@ class GpuGer(BlasOp):
Ger on the GPU. Ger on the GPU.
""" """
params_type = ParamsType(inplace=Scalar('bool'))
__props__ = ('inplace',) __props__ = ('inplace',)
def __init__(self, inplace=False): def __init__(self, inplace=False):
...@@ -282,9 +284,9 @@ class GpuGer(BlasOp): ...@@ -282,9 +284,9 @@ class GpuGer(BlasOp):
assert y.ndim == 1 assert y.ndim == 1
return Apply(self, [A, alpha, x, y], [A.type()]) return Apply(self, [A, alpha, x, y], [A.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out, params):
A, alpha, x, y = inp A, alpha, x, y = inp
inplace = self.inplace inplace = params['inplace']
if inplace and not A.flags.forc: if inplace and not A.flags.forc:
inplace = False inplace = False
out[0][0] = blas.ger(alpha, x, y, A, out[0][0] = blas.ger(alpha, x, y, A,
...@@ -292,33 +294,23 @@ class GpuGer(BlasOp): ...@@ -292,33 +294,23 @@ class GpuGer(BlasOp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], A=inp[0], alpha=inp[1], x=inp[2], y=inp[3], vars = dict(out=out[0], A=inp[0], alpha=inp[1], x=inp[2], y=inp[3],
fail=sub['fail'], name=name) fail=sub['fail'], name=name, params=sub['params'])
if self.inplace: code = """
code = """ if (%(params)s->inplace || !GpuArray_ISONESEGMENT(&%(A)s->ga)) {
if (!GpuArray_ISONESEGMENT(&%(A)s->ga)) { %(out)s = theano_try_copy(%(out)s, %(A)s);
%(out)s = theano_try_copy(%(out)s, %(A)s); if (%(out)s == NULL) {
if (%(out)s == NULL) { %(fail)s
%(fail)s }
} } else {
} else { Py_XDECREF(%(out)s);
Py_XDECREF(%(out)s); %(out)s = %(A)s;
%(out)s = %(A)s; Py_INCREF(%(out)s);
Py_INCREF(%(out)s); }
} if (pygpu_blas_rger(((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
""" % vars %(x)s, %(y)s, %(out)s, 0) == -1) {
else: %(fail)s
code = """ }
%(out)s = theano_try_copy(%(out)s, %(A)s); """ % vars
if (%(out)s == NULL) {
%(fail)s
}
""" % vars
code += """
if (pygpu_blas_rger(((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(x)s, %(y)s, %(out)s, 0) == -1) {
%(fail)s
}
""" % vars
if config.gpuarray.sync: if config.gpuarray.sync:
code += """ code += """
GpuArray_sync(&%(out)s->ga); GpuArray_sync(&%(out)s->ga);
...@@ -326,7 +318,7 @@ class GpuGer(BlasOp): ...@@ -326,7 +318,7 @@ class GpuGer(BlasOp):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
gpuger_no_inplace = GpuGer(inplace=False) gpuger_no_inplace = GpuGer(inplace=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论