Commit cdaa8834 authored by nouiz

Merge pull request #1274 from lamblin/fix_inc_set_subtensor1

Make inc/set_subtensor work on output of take.
...@@ -5079,6 +5079,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -5079,6 +5079,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
# nor have non-broadcastable dimensions where x is broadcastable. # nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
if y.ndim > x.ndim: if y.ndim > x.ndim:
raise TypeError(("Trying to increment a %d-dimensional " raise TypeError(("Trying to increment a %d-dimensional "
"subtensor with a %d-dimensional value.") % (x.ndim, y.ndim)) "subtensor with a %d-dimensional value.") % (x.ndim, y.ndim))
...@@ -5094,7 +5095,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -5094,7 +5095,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
y = addbroadcast(y, dim) y = addbroadcast(y, dim)
if not x.owner: if not x.owner:
raise TypeError('x must be result of a subtensor operation') raise TypeError('x must be the result of a subtensor operation')
# retrieve idx_list from x.owner # retrieve idx_list from x.owner
if isinstance(x.owner.op, Subtensor): if isinstance(x.owner.op, Subtensor):
...@@ -5121,8 +5122,38 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -5121,8 +5122,38 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
the_op = AdvancedIncSubtensor(inplace, the_op = AdvancedIncSubtensor(inplace,
set_instead_of_inc=set_instead_of_inc) set_instead_of_inc=set_instead_of_inc)
return the_op(real_x, y, coordvec_0, coordvec_1) return the_op(real_x, y, coordvec_0, coordvec_1)
elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y).T.
inner_incsubtensor = inc_subtensor(inner_x, y,
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing)
return x.owner.op(inner_incsubtensor, *x.owner.inputs[1:])
elif isinstance(x.owner.op, Reshape):
inner_x = x.owner.inputs[0]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
inner_incsubtensor = inc_subtensor(inner_x, y,
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing)
return inner_incsubtensor
else: else:
raise TypeError('x must be result of a subtensor operation') raise TypeError('x must be the result of a subtensor operation')
class IncSubtensor(Op): class IncSubtensor(Op):
......
...@@ -38,7 +38,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -38,7 +38,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp, var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_scalar_constant_value, ivector, reshape, scalar_from_tensor, scal, get_scalar_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, lvector, true_div, max, min, Split, roll, opt, ComplexError, lvector, lmatrix, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
...@@ -3626,6 +3626,53 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -3626,6 +3626,53 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
assert gof.graph.is_same_graph(s1, s2) assert gof.graph.is_same_graph(s1, s2)
def test_adv1_inc_sub_notlastdim(self):
    """set_subtensor/inc_subtensor on advanced indexing over a non-leading axis.

    Builds m[:, i] with a 1-d integer index vector over the second
    dimension, applies set_subtensor and inc_subtensor to it, and
    checks the compiled results against plain NumPy column updates.
    """
    mat = matrix('m')
    idx = lvector('i')

    set_graph = set_subtensor(mat[:, idx], 0)
    inc_graph = inc_subtensor(mat[:, idx], 1)
    fn = theano.function([mat, idx], [set_graph, inc_graph])

    mat_val = rand(3, 5)
    idx_val = randint_ranged(min=0, max=4, shape=(4,))
    got_set, got_inc = fn(mat_val, idx_val)

    # Reference: update one indexed column at a time, so repeated
    # indices accumulate increments the same way inc_subtensor does.
    want_set = mat_val.copy()
    want_inc = mat_val.copy()
    for col in idx_val:
        want_set[:, col] = 0
        want_inc[:, col] += 1

    assert numpy.allclose(got_set, want_set), (got_set, want_set)
    assert numpy.allclose(got_inc, want_inc), (got_inc, want_inc)
def test_adv1_inc_sub_notlastdim_2didx(self):
    """Like test_adv1_inc_sub_notlastdim, but with a 2-d integer index.

    Uses an lmatrix index over the second dimension of m and verifies
    set_subtensor/inc_subtensor against NumPy column-by-column updates
    over the flattened index values.
    """
    mat = matrix('m')
    idx = lmatrix('i')

    set_graph = set_subtensor(mat[:, idx], 0)
    inc_graph = inc_subtensor(mat[:, idx], 1)
    fn = theano.function([mat, idx], [set_graph, inc_graph])

    mat_val = rand(5, 7)
    idx_val = randint_ranged(min=0, max=6, shape=(4, 2))
    got_set, got_inc = fn(mat_val, idx_val)

    # Reference: flatten the index matrix and update each column in
    # turn, so duplicate entries accumulate increments as expected.
    want_set = mat_val.copy()
    want_inc = mat_val.copy()
    for col in idx_val.ravel():
        want_set[:, col] = 0
        want_inc[:, col] += 1

    assert numpy.allclose(got_set, want_set), (got_set, want_set)
    assert numpy.allclose(got_inc, want_inc), (got_inc, want_inc)
class TestIncSubtensor1(unittest.TestCase): class TestIncSubtensor1(unittest.TestCase):
# test inc_subtensor # test inc_subtensor
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment