提交 bf123427 authored 作者: Frederic Bastien's avatar Frederic Bastien

implement AdvancedIncSubtensor1.infer_shape.

上级 dba003f7
......@@ -4054,6 +4054,10 @@ class AdvancedIncSubtensor1(Op):
x[i] += y[j]
out[0] = x
def infer_shape(self, node, ishapes):
x, y, ilist = ishapes
return [x]
advanced_inc_subtensor = AdvancedIncSubtensor1()
class AdvancedSubtensor(Op):
......
......@@ -1615,23 +1615,34 @@ class T_subtensor(unittest.TestCase):
idx_ = shared(numpy.asarray(idx))
t = n[idx_]
gn = grad(sum(exp(t)), n)
f = function([], gn, mode=None)
f = function([], [gn, gn.shape], mode=None)
topo = f.maker.env.toposort()
if not fast_compile:
assert any([isinstance(node.op, AdvancedIncSubtensor1) and node.op.inplace for node in topo])
else:
assert any([isinstance(node.op, AdvancedIncSubtensor1) for node in topo])
assert any([isinstance(node.op, AdvancedSubtensor1) for node in topo])
gval = f()
gval, gshape = f()
good = numpy.zeros_like(data)
# good[idx] += numpy.exp(data[idx]) don't work when the same index is used many time
for i in idx:
good[i] += numpy.exp(data[i])
self.failUnless(numpy.allclose(gval, good), (gval, good))
self.failUnless(numpy.allclose(gshape, data.shape))
def fct(t):
return sum(exp(t[idx_]))
utt.verify_grad(fct, [data])
if idx is idxs[0]:
f = function([], [gn.shape, n[idx_].shape], mode=None)
topo = f.maker.env.toposort()
if not fast_compile:
self.failUnless(not any([isinstance(node.op, AdvancedIncSubtensor1) for node in topo]))
self.failUnless(not any([isinstance(node.op, AdvancedSubtensor1) for node in topo]))
f()
def test_grad_list(self):
data = numpy.random.rand(3)
idxs = [[i] for i in range(data.shape[0])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论