提交 bf123427 authored 作者: Frederic Bastien's avatar Frederic Bastien

implement AdvancedIncSubtensor1.infer_shape.

上级 dba003f7
...@@ -4054,6 +4054,10 @@ class AdvancedIncSubtensor1(Op): ...@@ -4054,6 +4054,10 @@ class AdvancedIncSubtensor1(Op):
x[i] += y[j] x[i] += y[j]
out[0] = x out[0] = x
def infer_shape(self, node, ishapes):
x, y, ilist = ishapes
return [x]
advanced_inc_subtensor = AdvancedIncSubtensor1() advanced_inc_subtensor = AdvancedIncSubtensor1()
class AdvancedSubtensor(Op): class AdvancedSubtensor(Op):
......
...@@ -1615,23 +1615,34 @@ class T_subtensor(unittest.TestCase): ...@@ -1615,23 +1615,34 @@ class T_subtensor(unittest.TestCase):
idx_ = shared(numpy.asarray(idx)) idx_ = shared(numpy.asarray(idx))
t = n[idx_] t = n[idx_]
gn = grad(sum(exp(t)), n) gn = grad(sum(exp(t)), n)
f = function([], gn, mode=None) f = function([], [gn, gn.shape], mode=None)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if not fast_compile: if not fast_compile:
assert any([isinstance(node.op, AdvancedIncSubtensor1) and node.op.inplace for node in topo]) assert any([isinstance(node.op, AdvancedIncSubtensor1) and node.op.inplace for node in topo])
else: else:
assert any([isinstance(node.op, AdvancedIncSubtensor1) for node in topo]) assert any([isinstance(node.op, AdvancedIncSubtensor1) for node in topo])
assert any([isinstance(node.op, AdvancedSubtensor1) for node in topo]) assert any([isinstance(node.op, AdvancedSubtensor1) for node in topo])
gval = f() gval, gshape = f()
good = numpy.zeros_like(data) good = numpy.zeros_like(data)
# good[idx] += numpy.exp(data[idx]) don't work when the same index is used many time # good[idx] += numpy.exp(data[idx]) don't work when the same index is used many time
for i in idx: for i in idx:
good[i] += numpy.exp(data[i]) good[i] += numpy.exp(data[i])
self.failUnless(numpy.allclose(gval, good), (gval, good)) self.failUnless(numpy.allclose(gval, good), (gval, good))
self.failUnless(numpy.allclose(gshape, data.shape))
def fct(t): def fct(t):
return sum(exp(t[idx_])) return sum(exp(t[idx_]))
utt.verify_grad(fct, [data]) utt.verify_grad(fct, [data])
if idx is idxs[0]:
f = function([], [gn.shape, n[idx_].shape], mode=None)
topo = f.maker.env.toposort()
if not fast_compile:
self.failUnless(not any([isinstance(node.op, AdvancedIncSubtensor1) for node in topo]))
self.failUnless(not any([isinstance(node.op, AdvancedSubtensor1) for node in topo]))
f()
def test_grad_list(self): def test_grad_list(self):
data = numpy.random.rand(3) data = numpy.random.rand(3)
idxs = [[i] for i in range(data.shape[0])] idxs = [[i] for i in range(data.shape[0])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论