提交 d721f2de authored 作者: Frederic Bastien's avatar Frederic Bastien

make and optimization to have the grad of AdvancedSubtensor1(AdvancedIncSubtensor1) work inplace.

上级 edeaf4a8
...@@ -4012,6 +4012,10 @@ advanced_subtensor1 = AdvancedSubtensor1() ...@@ -4012,6 +4012,10 @@ advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(Op): class AdvancedIncSubtensor1(Op):
"""Increments a subtensor using advanced slicing (list of index)""" """Increments a subtensor using advanced slicing (list of index)"""
def __init__(self, inplace=False):
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -4042,6 +4046,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -4042,6 +4046,7 @@ class AdvancedIncSubtensor1(Op):
# TODO opt to make this inplace # TODO opt to make this inplace
x, y, idx = inp x, y, idx = inp
out, = out_ out, = out_
if not self.inplace:
x = x.copy() x = x.copy()
# x[idx] += y don't work if the same index is present many times. # x[idx] += y don't work if the same index is present many times.
# It do it only once # It do it only once
......
...@@ -1158,6 +1158,18 @@ def local_inplace_setsubtensor(node): ...@@ -1158,6 +1158,18 @@ def local_inplace_setsubtensor(node):
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor, compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
@gof.local_optimizer([None])
def local_inplace_incsubtensor1(node):
"""
Also work for GpuIncSubtensor
"""
if isinstance(node.op, T.AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.__class__(inplace=True)
new_node = new_op(*node.inputs)
return [new_node]
return False
compile.optdb.register('local_inplace_incsubtensor1', TopoOptimizer(local_inplace_incsubtensor1,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
#################### ####################
# Rebroadcast opts # # Rebroadcast opts #
......
...@@ -1616,7 +1616,7 @@ class T_subtensor(unittest.TestCase): ...@@ -1616,7 +1616,7 @@ class T_subtensor(unittest.TestCase):
gn = grad(sum(exp(t)), n) gn = grad(sum(exp(t)), n)
f = function([], gn, mode=None) f = function([], gn, mode=None)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
assert any([isinstance(node.op, AdvancedIncSubtensor1) for node in topo]) assert any([isinstance(node.op, AdvancedIncSubtensor1) and node.op.inplace for node in topo])
assert any([isinstance(node.op, AdvancedSubtensor1) for node in topo]) assert any([isinstance(node.op, AdvancedSubtensor1) for node in topo])
gval = f() gval = f()
good = numpy.zeros_like(data) good = numpy.zeros_like(data)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论