提交 567d195a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update C code to use acc_dtype if needed.

If acc_dtype == output.dtype, nothing changes.
上级 e7f94330
...@@ -1350,6 +1350,14 @@ class CAReduce(Op): ...@@ -1350,6 +1350,14 @@ class CAReduce(Op):
idtype = input.type.dtype_specs()[1] idtype = input.type.dtype_specs()[1]
odtype = output.type.dtype_specs()[1] odtype = output.type.dtype_specs()[1]
if hasattr(self, 'acc_dtype'):
acc_type = TensorType(
broadcastable=node.outputs[0].broadcastable,
dtype=self.acc_dtype)
adtype = acc_type.dtype_specs()[1]
else:
adtype = odtype
axis = self.axis axis = self.axis
if axis is None: if axis is None:
axis = range(len(input.type.broadcastable)) axis = range(len(input.type.broadcastable))
...@@ -1367,13 +1375,25 @@ class CAReduce(Op): ...@@ -1367,13 +1375,25 @@ class CAReduce(Op):
for i, (input, iname) in enumerate(izip(node.inputs, inames)): for i, (input, iname) in enumerate(izip(node.inputs, inames)):
sub['lv%i' % i] = iname sub['lv%i' % i] = iname
decl = cgen.make_declare([order], [idtype], sub) decl = ""
if adtype != odtype:
# Create an accumulator variable different from the output
aname = "acc"
decl = acc_type.c_declare(aname, sub)
decl += acc_type.c_init(aname, sub)
else:
# the output is the accumulator variable
aname = oname
decl += cgen.make_declare([order], [idtype], sub)
checks = cgen.make_checks([order], [idtype], sub) checks = cgen.make_checks([order], [idtype], sub)
alloc = "" alloc = ""
i += 1 i += 1
sub['lv%i' % i] = oname sub['lv%i' % i] = oname
sub['olv'] = oname sub['olv'] = oname
# Allocate output buffer
alloc += cgen.make_declare( alloc += cgen.make_declare(
[range(nnested) + ['x'] * len(axis)], [range(nnested) + ['x'] * len(axis)],
[odtype], dict(sub, lv0=oname)) [odtype], dict(sub, lv0=oname))
...@@ -1382,6 +1402,19 @@ class CAReduce(Op): ...@@ -1382,6 +1402,19 @@ class CAReduce(Op):
[range(nnested) + ['x'] * len(axis)], [range(nnested) + ['x'] * len(axis)],
[odtype], dict(sub, lv0=oname)) [odtype], dict(sub, lv0=oname))
if adtype != odtype:
# Allocate accumulation buffer
sub['lv%i' % i] = aname
sub['olv'] = aname
alloc += cgen.make_declare(
[range(nnested) + ['x'] * len(axis)],
[adtype], dict(sub, lv0=aname))
alloc += cgen.make_alloc([order1], adtype, sub)
alloc += cgen.make_checks(
[range(nnested) + ['x'] * len(axis)],
[adtype], dict(sub, lv0=aname))
if hasattr(self.scalar_op, 'identity'): if hasattr(self.scalar_op, 'identity'):
identity = self.scalar_op.identity identity = self.scalar_op.identity
elif self.scalar_op in [scalar.maximum, scalar.minimum]: elif self.scalar_op in [scalar.maximum, scalar.minimum]:
...@@ -1425,7 +1458,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1425,7 +1458,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
task0_decl = ( task0_decl = (
"%(dtype)s& %(name)s_i = *%(name)s_iter;\n" "%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
"%(name)s_i = %(identity)s;" "%(name)s_i = %(identity)s;"
% dict(dtype=odtype, name=onames[0], identity=identity)) % dict(dtype=adtype, name=aname, identity=identity))
task1_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n" task1_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
% dict(dtype=idtype, name=inames[0])) % dict(dtype=idtype, name=inames[0]))
...@@ -1438,8 +1471,8 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1438,8 +1471,8 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
[Scalar(dtype=output.type.dtype)() [Scalar(dtype=output.type.dtype)()
for input in node.outputs]), for input in node.outputs]),
None, None,
["%s_i" % onames[0], "%s_i" % inames[0]], ["%s_i" % aname, "%s_i" % inames[0]],
["%s_i" % onames[0]], ["%s_i" % aname],
sub) sub)
code1 = """ code1 = """
{ {
...@@ -1461,8 +1494,16 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1461,8 +1494,16 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
all_code = [task0_decl + code1] all_code = [task0_decl + code1]
loop = cgen.make_loop( loop = cgen.make_loop(
[order, range(nnested) + ['x'] * len(axis)], [order, range(nnested) + ['x'] * len(axis)],
[idtype, odtype], all_code, sub) [idtype, adtype], all_code, sub)
return decl, checks, alloc, loop
end = ""
if adtype != odtype:
end = """
PyArray_CopyInto(%(oname)s, %(aname)s);
""" % dict(oname=oname, aname=aname)
end += acc_type.c_cleanup(aname, sub)
return decl, checks, alloc, loop, end
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
code = "\n".join(self._c_all(node, name, inames, onames, sub)) code = "\n".join(self._c_all(node, name, inames, onames, sub))
...@@ -1473,7 +1514,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1473,7 +1514,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
return ['<vector>', '<algorithm>'] return ['<vector>', '<algorithm>']
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
version = [4] # the version corresponding to the c code in this Op version = [5] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(self.scalar_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论