提交 df16dbcc authored 作者: Frederic Bastien's avatar Frederic Bastien

Make code more readable.

上级 8976c742
......@@ -423,12 +423,11 @@ class GpuGemmBatch(BlasOp):
def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], C=inp[0], alpha=inp[1], A=inp[2], B=inp[3],
beta=inp[4], fail=sub['fail'], name=name)
beta=inp[4], inplace=int(self.inplace),
fail=sub['fail'], name=name)
code = """
int err;
"""
if self.inplace:
code += """
if (%(inplace)s){
if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) {
%(out)s = theano_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) {
......@@ -439,15 +438,12 @@ class GpuGemmBatch(BlasOp):
%(out)s = %(C)s;
Py_INCREF(%(out)s);
}
""" % vars
else:
code += """
} else {
%(out)s = theano_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) {
%(fail)s
}
""" % vars
code += """
}
err = GpuArray_rgemmBatch_3d(
cb_no_trans, cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
......@@ -467,7 +463,7 @@ class GpuGemmBatch(BlasOp):
return code
def c_code_cache_version(self):
return (1,)
return (2,)
gpugemmbatch_no_inplace = GpuGemmBatch(inplace=False)
gpugemmbatch_inplace = GpuGemmBatch(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论