提交 df16dbcc authored 作者: Frederic Bastien's avatar Frederic Bastien

Make code more readable.

上级 8976c742
...@@ -423,12 +423,11 @@ class GpuGemmBatch(BlasOp): ...@@ -423,12 +423,11 @@ class GpuGemmBatch(BlasOp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], C=inp[0], alpha=inp[1], A=inp[2], B=inp[3], vars = dict(out=out[0], C=inp[0], alpha=inp[1], A=inp[2], B=inp[3],
beta=inp[4], fail=sub['fail'], name=name) beta=inp[4], inplace=int(self.inplace),
fail=sub['fail'], name=name)
code = """ code = """
int err; int err;
""" if (%(inplace)s){
if self.inplace:
code += """
if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) { if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) {
%(out)s = theano_try_copy(%(out)s, %(C)s); %(out)s = theano_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) { if (%(out)s == NULL) {
...@@ -439,15 +438,12 @@ class GpuGemmBatch(BlasOp): ...@@ -439,15 +438,12 @@ class GpuGemmBatch(BlasOp):
%(out)s = %(C)s; %(out)s = %(C)s;
Py_INCREF(%(out)s); Py_INCREF(%(out)s);
} }
""" % vars } else {
else:
code += """
%(out)s = theano_try_copy(%(out)s, %(C)s); %(out)s = theano_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) { if (%(out)s == NULL) {
%(fail)s %(fail)s
} }
""" % vars }
code += """
err = GpuArray_rgemmBatch_3d( err = GpuArray_rgemmBatch_3d(
cb_no_trans, cb_no_trans, cb_no_trans, cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0], ((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
...@@ -467,7 +463,7 @@ class GpuGemmBatch(BlasOp): ...@@ -467,7 +463,7 @@ class GpuGemmBatch(BlasOp):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
gpugemmbatch_no_inplace = GpuGemmBatch(inplace=False) gpugemmbatch_no_inplace = GpuGemmBatch(inplace=False)
gpugemmbatch_inplace = GpuGemmBatch(inplace=True) gpugemmbatch_inplace = GpuGemmBatch(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论