提交 567d195a 作者: Pascal Lamblin

Update C code to use acc_dtype if needed.

If acc_dtype == output.dtype, nothing changes.
上级 e7f94330
......@@ -1350,6 +1350,14 @@ class CAReduce(Op):
idtype = input.type.dtype_specs()[1]
odtype = output.type.dtype_specs()[1]
if hasattr(self, 'acc_dtype'):
acc_type = TensorType(
broadcastable=node.outputs[0].broadcastable,
dtype=self.acc_dtype)
adtype = acc_type.dtype_specs()[1]
else:
adtype = odtype
axis = self.axis
if axis is None:
axis = range(len(input.type.broadcastable))
......@@ -1367,13 +1375,25 @@ class CAReduce(Op):
for i, (input, iname) in enumerate(izip(node.inputs, inames)):
sub['lv%i' % i] = iname
decl = cgen.make_declare([order], [idtype], sub)
decl = ""
if adtype != odtype:
# Create an accumulator variable different from the output
aname = "acc"
decl = acc_type.c_declare(aname, sub)
decl += acc_type.c_init(aname, sub)
else:
# the output is the accumulator variable
aname = oname
decl += cgen.make_declare([order], [idtype], sub)
checks = cgen.make_checks([order], [idtype], sub)
alloc = ""
i += 1
sub['lv%i' % i] = oname
sub['olv'] = oname
# Allocate output buffer
alloc += cgen.make_declare(
[range(nnested) + ['x'] * len(axis)],
[odtype], dict(sub, lv0=oname))
......@@ -1382,6 +1402,19 @@ class CAReduce(Op):
[range(nnested) + ['x'] * len(axis)],
[odtype], dict(sub, lv0=oname))
if adtype != odtype:
# Allocate accumulation buffer
sub['lv%i' % i] = aname
sub['olv'] = aname
alloc += cgen.make_declare(
[range(nnested) + ['x'] * len(axis)],
[adtype], dict(sub, lv0=aname))
alloc += cgen.make_alloc([order1], adtype, sub)
alloc += cgen.make_checks(
[range(nnested) + ['x'] * len(axis)],
[adtype], dict(sub, lv0=aname))
if hasattr(self.scalar_op, 'identity'):
identity = self.scalar_op.identity
elif self.scalar_op in [scalar.maximum, scalar.minimum]:
......@@ -1425,7 +1458,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
task0_decl = (
"%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
"%(name)s_i = %(identity)s;"
% dict(dtype=odtype, name=onames[0], identity=identity))
% dict(dtype=adtype, name=aname, identity=identity))
task1_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
% dict(dtype=idtype, name=inames[0]))
......@@ -1438,8 +1471,8 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
[Scalar(dtype=output.type.dtype)()
for input in node.outputs]),
None,
["%s_i" % onames[0], "%s_i" % inames[0]],
["%s_i" % onames[0]],
["%s_i" % aname, "%s_i" % inames[0]],
["%s_i" % aname],
sub)
code1 = """
{
......@@ -1461,8 +1494,16 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
all_code = [task0_decl + code1]
loop = cgen.make_loop(
[order, range(nnested) + ['x'] * len(axis)],
[idtype, odtype], all_code, sub)
return decl, checks, alloc, loop
[idtype, adtype], all_code, sub)
end = ""
if adtype != odtype:
end = """
PyArray_CopyInto(%(oname)s, %(aname)s);
""" % dict(oname=oname, aname=aname)
end += acc_type.c_cleanup(aname, sub)
return decl, checks, alloc, loop, end
def c_code(self, node, name, inames, onames, sub):
code = "\n".join(self._c_all(node, name, inames, onames, sub))
......@@ -1473,7 +1514,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
return ['<vector>', '<algorithm>']
def c_code_cache_version_apply(self, node):
version = [4] # the version corresponding to the c code in this Op
version = [5] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论