提交 fac6c2b2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix in advanced_inc_subtensor1 if index is broadcastable

上级 a7900bd8
...@@ -1157,7 +1157,11 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -1157,7 +1157,11 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
inplace=inplace, inplace=inplace,
set_instead_of_inc=set_instead_of_inc, set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing) tolerate_inplace_aliasing=tolerate_inplace_aliasing)
return x.owner.op(inner_incsubtensor, *x.owner.inputs[1:]) # The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return inner_incsubtensor.dimshuffle(x.owner.op.new_order)
elif isinstance(x.owner.op, theano.tensor.Reshape): elif isinstance(x.owner.op, theano.tensor.Reshape):
# This case happens when the indices are not arranged as a vector, but # This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor # as a higher-dimensional array. This is handled by the subtensor
......
...@@ -1325,6 +1325,20 @@ class TestIncSubtensor1(unittest.TestCase): ...@@ -1325,6 +1325,20 @@ class TestIncSubtensor1(unittest.TestCase):
utt.assert_allclose(a2val[2], mval[2] * 3) utt.assert_allclose(a2val[2], mval[2] * 3)
utt.assert_allclose(a2val[3], mval[3] * 2) utt.assert_allclose(a2val[3], mval[3] * 2)
def test_inc_bcastableidx(self):
idx = tensor.constant([0])
c_inc = tensor.col()
m_inc = tensor.matrix()
out1 = inc_subtensor(self.m[:, idx], c_inc)
out2 = inc_subtensor(self.m[:, idx], m_inc)
f = theano.function([self.m, c_inc, m_inc], [out1, out2])
mval = self.rng.random_sample((10, 5))
incval = self.rng.random_sample((10, 1)).astype(config.floatX)
out1val, out2val = f(mval, incval, incval)
utt.assert_allclose(out1val, out2val)
inplace_increment_missing = SkipTest( inplace_increment_missing = SkipTest(
"inc_subtensor with advanced indexing not enabled. " "inc_subtensor with advanced indexing not enabled. "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论