提交 c03ed9b8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed elemwise's behavior with duplicate inputs

上级 77a11909
...@@ -124,6 +124,14 @@ class _test_Broadcast(unittest.TestCase): ...@@ -124,6 +124,14 @@ class _test_Broadcast(unittest.TestCase):
zv = xv + yv zv = xv + yv
assert (f(xv, yv) == zv).all() assert (f(xv, yv) == zv).all()
def test_same_inputs(self):
x = modes.build(Tensor('float64', [0, 0], name = 'x'))
e = Broadcast(Add, (x, x)).out
f = gof.CLinker(env([x], [e])).make_function(inplace = False)
xv = numpy.random.rand(2, 2)
zv = xv + xv
assert (f(xv) == zv).all()
class _test_CAReduce(unittest.TestCase): class _test_CAReduce(unittest.TestCase):
......
...@@ -241,6 +241,7 @@ class Broadcast(Op, Destroyer): ...@@ -241,6 +241,7 @@ class Broadcast(Op, Destroyer):
return ret return ret
def grad(self, inputs, ograds): def grad(self, inputs, ograds):
ograds = map(astensor, ograds)
shadow = self.shadow shadow = self.shadow
scalar_ograds = [Scalar(dtype = ograd.dtype) for ograd in ograds] scalar_ograds = [Scalar(dtype = ograd.dtype) for ograd in ograds]
scalar_igrads = shadow.grad(shadow.inputs, scalar_ograds) scalar_igrads = shadow.grad(shadow.inputs, scalar_ograds)
...@@ -320,11 +321,17 @@ class Broadcast(Op, Destroyer): ...@@ -320,11 +321,17 @@ class Broadcast(Op, Destroyer):
self.ufunc(*([input.data for input in self.inputs] + output_storage)) self.ufunc(*([input.data for input in self.inputs] + output_storage))
def _c_all(self, inames, onames, sub): def _c_all(self, inames, onames, sub):
_inames = inames
_onames = onames
inames = gof.utils.uniq(inames)
inputs = gof.utils.uniq(self.inputs)
defines = "" defines = ""
undefs = "" undefs = ""
dmap = self.destroy_map() dmap = self.destroy_map()
idtypes = [input.dtype_specs()[1] for input in self.inputs] idtypes = [input.dtype_specs()[1] for input in inputs]
real = zip(*[(r, s, r.dtype_specs()[1]) real = zip(*[(r, s, r.dtype_specs()[1])
for r, s in zip(self.outputs, onames) if r not in dmap]) for r, s in zip(self.outputs, onames) if r not in dmap])
...@@ -340,10 +347,10 @@ class Broadcast(Op, Destroyer): ...@@ -340,10 +347,10 @@ class Broadcast(Op, Destroyer):
else: else:
aliased_outputs, aliased_onames = [], [] aliased_outputs, aliased_onames = [], []
orders = [[x and 'x' or i for i, x in enumerate(input.broadcastable)] for input in self.inputs] orders = [[x and 'x' or i for i, x in enumerate(input.broadcastable)] for input in inputs]
nnested = len(orders[0]) nnested = len(orders[0])
sub = dict(sub) sub = dict(sub)
for i, (input, iname) in enumerate(zip(self.inputs, inames)): for i, (input, iname) in enumerate(zip(inputs, inames)):
sub['lv%i' % i] = iname sub['lv%i' % i] = iname
decl = cgen.make_declare(orders, idtypes, sub) decl = cgen.make_declare(orders, idtypes, sub)
checks = cgen.make_checks(orders, idtypes, sub) checks = cgen.make_checks(orders, idtypes, sub)
...@@ -358,7 +365,7 @@ class Broadcast(Op, Destroyer): ...@@ -358,7 +365,7 @@ class Broadcast(Op, Destroyer):
alloc += cgen.make_checks([range(nnested)], [odtype], dict(sub, lv0 = oname)) alloc += cgen.make_checks([range(nnested)], [odtype], dict(sub, lv0 = oname))
for output, oname in zip(aliased_outputs, aliased_onames): for output, oname in zip(aliased_outputs, aliased_onames):
iname = inames[self.inputs.index(dmap[output][0])] iname = inames[inputs.index(dmap[output][0])]
alloc += """ alloc += """
if (%(oname)s) { if (%(oname)s) {
Py_XDECREF(%(oname)s); Py_XDECREF(%(oname)s);
...@@ -368,8 +375,8 @@ class Broadcast(Op, Destroyer): ...@@ -368,8 +375,8 @@ class Broadcast(Op, Destroyer):
""" % locals() """ % locals()
defines += "#define %(oname)s_i %(iname)s_i" % locals() defines += "#define %(oname)s_i %(iname)s_i" % locals()
undefs += "#undef %(oname)s_i" % locals() undefs += "#undef %(oname)s_i" % locals()
task_code = self.shadow.c_code(["%s_i" % s for s in inames], task_code = self.shadow.c_code(["%s_i" % s for s in _inames],
["%s_i" % s for s in onames], ["%s_i" % s for s in onames],
sub) sub)
task_decl = "".join(["%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals() for name, dtype in zip(inames + list(real_onames), idtypes + list(real_odtypes))]) task_decl = "".join(["%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals() for name, dtype in zip(inames + list(real_onames), idtypes + list(real_odtypes))])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论