提交 c669fc61 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -195,7 +195,8 @@ class Elemwise(Op): ...@@ -195,7 +195,8 @@ class Elemwise(Op):
Elemwise(log)(rand(3, 4, 5)) Elemwise(log)(rand(3, 4, 5))
""" """
def __init__(self, scalar_op, inplace_pattern = {}): def __init__(self, scalar_op, inplace_pattern = {}, name = None):
self.name = name
self.scalar_op = scalar_op self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items()) self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items())
...@@ -238,10 +239,13 @@ class Elemwise(Op): ...@@ -238,10 +239,13 @@ class Elemwise(Op):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __str__(self): def __str__(self):
if self.inplace_pattern: if self.name is None:
return "Broadcast{%s}%s" % (self.scalar_op, str(self.inplace_pattern)) if self.inplace_pattern:
return "Broadcast{%s}%s" % (self.scalar_op, str(self.inplace_pattern))
else:
return "Broadcast{%s}" % (self.scalar_op)
else: else:
return "Broadcast{%s}" % (self.scalar_op) return self.name
def grad(self, inputs, ograds): def grad(self, inputs, ograds):
ograds = map(as_tensor, ograds) # this shouldn't be necessary... ograds = map(as_tensor, ograds) # this shouldn't be necessary...
......
...@@ -255,6 +255,12 @@ class _test_all(unittest.TestCase): ...@@ -255,6 +255,12 @@ class _test_all(unittest.TestCase):
g.replace(e0, new_e0) g.replace(e0, new_e0)
assert g.consistent() assert g.consistent()
# def test_aliased_inputs(self):
# x, y, z = inputs()
# e = add_in_place(x, transpose_view(x))
# g = Env([x], [e], False)
# assert not g.consistent()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -243,9 +243,17 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -243,9 +243,17 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
except AttributeError, AbstractFunctionError: _dmap = {} except AttributeError, AbstractFunctionError: _dmap = {}
vmap = {} vmap = {}
for oidx, iidxs in _vmap.items(): for oidx, iidxs in _vmap.items():
if oidx < 0 or oidx >= node.nout:
raise ValueError("In %s.view_map: output index out of range" % node.op, oidx, _vmap)
if any(iidx < 0 or iidx >= node.nin for iidx in iidxs):
raise ValueError("In %s.view_map: input index out of range" % node.op, iidxs, _vmap)
vmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs] vmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
dmap = {} dmap = {}
for oidx, iidxs in _dmap.items(): for oidx, iidxs in _dmap.items():
if oidx < 0 or oidx >= node.nout:
raise ValueError("In %s.destroy_map: output index out of range" % node.op, oidx, _dmap)
if any(iidx < 0 or iidx >= node.nin for iidx in iidxs):
raise ValueError("In %s.destroy_map: input index out of range" % node.op, iidxs, _dmap)
dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs] dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
return vmap, dmap return vmap, dmap
......
...@@ -398,7 +398,7 @@ s2t.TensorValue = TensorValue ...@@ -398,7 +398,7 @@ s2t.TensorValue = TensorValue
def _elemwise(scalar_op, name): def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op) straight = s2t.Elemwise(scalar_op)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0}) inplace = s2t.Elemwise(inplace_scalar_op, {0: 0}, name = name)
return straight, inplace return straight, inplace
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论