提交 53532957 authored 作者: abergeron's avatar abergeron

Merge pull request #1993 from nouiz/crash_inc_sub_grad

Crash fix in IncSubtensor.grad
...@@ -594,8 +594,7 @@ class Subtensor(Op): ...@@ -594,8 +594,7 @@ class Subtensor(Op):
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim, def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
c_prefix=None, c_prefix=None,
strides_mul=None, strides_mul=None):
):
""" """
The parameters c_prefix are there to allow reusing this The parameters c_prefix are there to allow reusing this
function on PyArray and CudaNdarray object. function on PyArray and CudaNdarray object.
...@@ -682,7 +681,8 @@ class Subtensor(Op): ...@@ -682,7 +681,8 @@ class Subtensor(Op):
subensor_spec = "npy_intp * subtensor_spec = NULL;" subensor_spec = "npy_intp * subtensor_spec = NULL;"
if is_slice: if is_slice:
is_slice_init = "int is_slice[] = {" + ",".join([str(s) for s in is_slice]) + "};" is_slice_init = "int is_slice[] = {" + ",".join([str(s) for s in
is_slice]) + "};"
else: else:
is_slice_init = "int* is_slice = NULL;" is_slice_init = "int* is_slice = NULL;"
subtensor_init = "\n".join(init_cmds) subtensor_init = "\n".join(init_cmds)
...@@ -897,7 +897,8 @@ class Subtensor(Op): ...@@ -897,7 +897,8 @@ class Subtensor(Op):
%(z)s = xview; %(z)s = xview;
""" % locals() """ % locals()
return decl + checkNDim + "{" + get_xview + build_view + finish_view + "}" return (decl + checkNDim +
"{" + get_xview + build_view + finish_view + "}")
def c_code_cache_version(self): def c_code_cache_version(self):
hv = self.helper_c_code_cache_version() hv = self.helper_c_code_cache_version()
...@@ -1022,9 +1023,9 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -1022,9 +1023,9 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
destroyhandler_tolerate_aliased = [[0, 1]] destroyhandler_tolerate_aliased = [[0, 1]]
else: else:
destroyhandler_tolerate_aliased = [] destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc, the_op = IncSubtensor(
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased x.owner.op.idx_list, inplace, set_instead_of_inc,
) destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased)
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:] real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs) return the_op(real_x, y, *real_idxargs)
...@@ -1105,10 +1106,10 @@ class IncSubtensor(Op): ...@@ -1105,10 +1106,10 @@ class IncSubtensor(Op):
self.set_instead_of_inc = set_instead_of_inc self.set_instead_of_inc = set_instead_of_inc
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) \ return (type(self) == type(other) and
and self.idx_list == other.idx_list \ self.idx_list == other.idx_list and
and self.inplace == other.inplace \ self.inplace == other.inplace and
and self.set_instead_of_inc == other.set_instead_of_inc self.set_instead_of_inc == other.set_instead_of_inc)
def __hash__(self): def __hash__(self):
msg = [] msg = []
...@@ -1120,12 +1121,12 @@ class IncSubtensor(Op): ...@@ -1120,12 +1121,12 @@ class IncSubtensor(Op):
idx_list = tuple(msg) idx_list = tuple(msg)
# backport # backport
#idx_list = tuple((entry.start, entry.stop, entry.step) # idx_list = tuple((entry.start, entry.stop, entry.step)
# if isinstance(entry, slice) # if isinstance(entry, slice)
# else entry # else entry
# for entry in self.idx_list) # for entry in self.idx_list)
return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) \ return (hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) ^
^ hash(self.set_instead_of_inc) hash(self.set_instead_of_inc))
def __str__(self): def __str__(self):
indices = [] indices = []
...@@ -1225,7 +1226,7 @@ class IncSubtensor(Op): ...@@ -1225,7 +1226,7 @@ class IncSubtensor(Op):
if not self.set_instead_of_inc: if not self.set_instead_of_inc:
sub_x += y sub_x += y
else: else:
#sub_x += -sub_x + y # sub_x += -sub_x + y
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
else: else:
# scalar case # scalar case
...@@ -1295,7 +1296,7 @@ class IncSubtensor(Op): ...@@ -1295,7 +1296,7 @@ class IncSubtensor(Op):
** helper_args ** helper_args
) )
#Make a view on the output, as we will write into it. # Make a view on the output, as we will write into it.
alloc_zview = self.make_view_array(z, view_ndim) alloc_zview = self.make_view_array(z, view_ndim)
build_view = """ build_view = """
...@@ -1460,7 +1461,8 @@ class IncSubtensor(Op): ...@@ -1460,7 +1461,8 @@ class IncSubtensor(Op):
gx = g_output gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
if gy.broadcastable != y.broadcastable: if gy.broadcastable != y.broadcastable:
y_broad = (True,) * (gy.ndim - y.ndim) + y.broadcastable y_dim_added = gy.ndim - y.ndim
y_broad = (True,) * y_dim_added + y.broadcastable
assert sum(gy.broadcastable) < sum(y_broad) assert sum(gy.broadcastable) < sum(y_broad)
axis_to_sum = [] axis_to_sum = []
for i in range(gy.ndim): for i in range(gy.ndim):
...@@ -1468,7 +1470,7 @@ class IncSubtensor(Op): ...@@ -1468,7 +1470,7 @@ class IncSubtensor(Op):
axis_to_sum.append(i) axis_to_sum.append(i)
elif (gy.broadcastable[i] is True and elif (gy.broadcastable[i] is True and
y_broad[i] is False): y_broad[i] is False):
# This mean that THeano where able to infer that # This mean that Theano where able to infer that
# gy.shape[i] is 1, so y.shape[i] is 1, but we # gy.shape[i] is 1, so y.shape[i] is 1, but we
# didn't know it. It is fine. # didn't know it. It is fine.
pass pass
...@@ -1476,7 +1478,10 @@ class IncSubtensor(Op): ...@@ -1476,7 +1478,10 @@ class IncSubtensor(Op):
assert gy.broadcastable[i] == y_broad[i] assert gy.broadcastable[i] == y_broad[i]
gy = gy.sum(axis=axis_to_sum, keepdims=True) gy = gy.sum(axis=axis_to_sum, keepdims=True)
if gy.ndim != y.ndim: if gy.ndim != y.ndim:
gy = gy.dimshuffle(*range(y.ndim, gy.ndim)) assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable assert gy.broadcastable == y.broadcastable
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
...@@ -1540,8 +1545,9 @@ class AdvancedSubtensor1(Op): ...@@ -1540,8 +1545,9 @@ class AdvancedSubtensor1(Op):
if not numpy.can_cast(i.dtype, numpy.intp): if not numpy.can_cast(i.dtype, numpy.intp):
# Check if there was actually an incorrect conversion # Check if there was actually an incorrect conversion
if numpy.any(i != i_): if numpy.any(i != i_):
raise IndexError('index contains values that are bigger ' raise IndexError(
'than the maximum array size on this system.', i) 'index contains values that are bigger '
'than the maximum array size on this system.', i)
i = i_ i = i_
out[0] = x.take(i, axis=0, out=o) out[0] = x.take(i, axis=0, out=o)
...@@ -1732,9 +1738,10 @@ class AdvancedIncSubtensor1(Op): ...@@ -1732,9 +1738,10 @@ class AdvancedIncSubtensor1(Op):
opname = 'set' opname = 'set'
else: else:
opname = 'increment' opname = 'increment'
raise TypeError('cannot %s x subtensor with ndim=%s' raise TypeError(
' by y with ndim=%s to x subtensor with ndim=%s ' % ( 'cannot %s x subtensor with ndim=%s'
opname, x_.type.ndim, y_.type.ndim)) ' by y with ndim=%s to x subtensor with ndim=%s ' % (
opname, x_.type.ndim, y_.type.ndim))
return Apply(self, [x_, y_, ilist_], [x_.type()]) return Apply(self, [x_, y_, ilist_], [x_.type()])
...@@ -1837,7 +1844,7 @@ def adv_index_broadcastable_pattern(a, idx): ...@@ -1837,7 +1844,7 @@ def adv_index_broadcastable_pattern(a, idx):
newidx = tuple(map(replace_slice, idx)) newidx = tuple(map(replace_slice, idx))
#2 - True = 1; 2 - False = 2 # 2 - True = 1; 2 - False = 2
fakeshape = [2 - bc for bc in a.broadcastable] fakeshape = [2 - bc for bc in a.broadcastable]
retshape = numpy.empty(fakeshape)[newidx].shape retshape = numpy.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape]) return tuple([dim == 1 for dim in retshape])
...@@ -1867,7 +1874,7 @@ class AdvancedSubtensor(Op): ...@@ -1867,7 +1874,7 @@ class AdvancedSubtensor(Op):
return gof.Apply(self, return gof.Apply(self,
(x,) + index, (x,) + index,
[theano.tensor.tensor(dtype=x.type.dtype, [theano.tensor.tensor(dtype=x.type.dtype,
broadcastable=bcast)]) broadcastable=bcast)])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -1897,13 +1904,11 @@ class AdvancedSubtensor(Op): ...@@ -1897,13 +1904,11 @@ class AdvancedSubtensor(Op):
if (numpy.__version__ <= '1.6.1' and if (numpy.__version__ <= '1.6.1' and
out[0].size != numpy.uint32(out[0].size)): out[0].size != numpy.uint32(out[0].size)):
warnings.warn( warnings.warn(
'Numpy versions 1.6.1 and below have a bug preventing ' 'Numpy versions 1.6.1 and below have a bug preventing '
'advanced indexing from correctly filling arrays that ' 'advanced indexing from correctly filling arrays that '
'are too big (>= 2^32 elements). It is possible that ' 'are too big (>= 2^32 elements). It is possible that '
'out[0] (%s), with shape %s, is not correctly filled.' 'out[0] (%s), with shape %s, is not correctly filled.'
% (out[0], out[0].shape)) % (out[0], out[0].shape))
# return
#raise NotImplementedError()
def connection_pattern(self, node): def connection_pattern(self, node):
...@@ -1955,8 +1960,9 @@ class AdvancedIncSubtensor(Op): ...@@ -1955,8 +1960,9 @@ class AdvancedIncSubtensor(Op):
def __str__(self): def __str__(self):
return "%s{%s, %s}" % (self.__class__.__name__, return "%s{%s, %s}" % (self.__class__.__name__,
"inplace=" + str(self.inplace), "inplace=" + str(self.inplace),
" set_instead_of_inc=" + str(self. set_instead_of_inc)) " set_instead_of_inc=" +
str(self. set_instead_of_inc))
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x = theano.tensor.as_tensor_variable(x) x = theano.tensor.as_tensor_variable(x)
...@@ -1990,17 +1996,18 @@ class AdvancedIncSubtensor(Op): ...@@ -1990,17 +1996,18 @@ class AdvancedIncSubtensor(Op):
op.allow_legacy_perform = True op.allow_legacy_perform = True
else: else:
raise NotImplementedError( raise NotImplementedError(
'Could not import inplace_increment, so some advanced ' 'Could not import inplace_increment, so some advanced '
'indexing features are disabled. They will be ' 'indexing features are disabled. They will be '
'available if you update NumPy to version 1.8 or ' 'available if you update NumPy to version 1.8 or '
'later, or to the latest development version. ' 'later, or to the latest development version. '
'You may need to clear the cache (theano-cache clear) ' 'You may need to clear the cache (theano-cache clear) '
'afterwards.') 'afterwards.')
return gof.Apply(op, return gof.Apply(op,
(x, y) + inputs, (x, y) + inputs,
[theano.tensor.tensor(dtype=x.type.dtype, [theano.tensor.tensor(
broadcastable=x.type.broadcastable)]) dtype=x.type.dtype,
broadcastable=x.type.broadcastable)])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
# TODO: 1. opt to make this in place 2. generalize as described in # TODO: 1. opt to make this in place 2. generalize as described in
...@@ -2020,21 +2027,21 @@ class AdvancedIncSubtensor(Op): ...@@ -2020,21 +2027,21 @@ class AdvancedIncSubtensor(Op):
out[0][inputs[2:]] += inputs[1] out[0][inputs[2:]] += inputs[1]
else: else:
raise NotImplementedError( raise NotImplementedError(
'Could not import inplace_increment, so some advanced ' 'Could not import inplace_increment, so some advanced '
'indexing features are disabled. They will be ' 'indexing features are disabled. They will be '
'available if you update NumPy to version 1.8 or ' 'available if you update NumPy to version 1.8 or '
'later, or to the latest development version. ' 'later, or to the latest development version. '
'You may need to clear the cache (theano-cache clear) ' 'You may need to clear the cache (theano-cache clear) '
'afterwards.') 'afterwards.')
if (numpy.__version__ <= '1.6.1' and if (numpy.__version__ <= '1.6.1' and
out[0].size != numpy.uint32(out[0].size)): out[0].size != numpy.uint32(out[0].size)):
warnings.warn( warnings.warn(
'Numpy versions 1.6.1 and below have a bug preventing ' 'Numpy versions 1.6.1 and below have a bug preventing '
'advanced indexing from correctly filling arrays that ' 'advanced indexing from correctly filling arrays that '
'are too big (>= 2^32 elements). It is possible that ' 'are too big (>= 2^32 elements). It is possible that '
'out[0] (%s), with shape %s, is not correctly filled.' 'out[0] (%s), with shape %s, is not correctly filled.'
% (out[0], out[0].shape)) % (out[0], out[0].shape))
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
return [ishapes[0]] return [ishapes[0]]
...@@ -2092,6 +2099,6 @@ def take(a, indices, axis=None, mode='raise'): ...@@ -2092,6 +2099,6 @@ def take(a, indices, axis=None, mode='raise'):
ndim = indices.ndim ndim = indices.ndim
else: else:
shape = theano.tensor.concatenate( shape = theano.tensor.concatenate(
[a.shape[:axis], indices.shape, a.shape[axis + 1:]]) [a.shape[:axis], indices.shape, a.shape[axis + 1:]])
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim) return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
...@@ -83,11 +83,11 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -83,11 +83,11 @@ class Test_inc_subtensor(unittest.TestCase):
f(rng_randX(3, 1), rng_randX(1)) f(rng_randX(3, 1), rng_randX(1))
# These ones should not # These ones should not
self.assertRaises(ValueError, self.assertRaises(ValueError,
f, rng_randX(3, 1), rng_randX(2)) f, rng_randX(3, 1), rng_randX(2))
self.assertRaises(ValueError, self.assertRaises(ValueError,
f, rng_randX(3, 1), rng_randX(3)) f, rng_randX(3, 1), rng_randX(3))
self.assertRaises(ValueError, self.assertRaises(ValueError,
f, rng_randX(3, 1), rng_randX(0)) f, rng_randX(3, 1), rng_randX(0))
def test_simple_3d(self): def test_simple_3d(self):
"""Increments or sets part of a tensor by a scalar using full slice and """Increments or sets part of a tensor by a scalar using full slice and
...@@ -100,30 +100,42 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -100,30 +100,42 @@ class Test_inc_subtensor(unittest.TestCase):
sl2 = slice(sl2_end) sl2 = slice(sl2_end)
sl3 = 2 sl3 = 2
for do_set in [True, False]: val_a = numpy.ones((5, 3, 4))
print "Set", do_set val_inc = 2.3
val_sl2_end = 2
if do_set: for method in [tt.set_subtensor, tt.inc_subtensor]:
resut = tt.set_subtensor(a[sl1, sl3, sl2], increment) print "MethodSet", method
else:
resut = tt.inc_subtensor(a[sl1, sl3, sl2], increment)
f = theano.function([a, increment, sl2_end], resut) resut = method(a[sl1, sl3, sl2], increment)
val_a = numpy.ones((5, 3, 4)) f = theano.function([a, increment, sl2_end], resut)
val_inc = 2.3
val_sl2_end = 2
expected_result = numpy.copy(val_a) expected_result = numpy.copy(val_a)
result = f(val_a, val_inc, val_sl2_end) result = f(val_a, val_inc, val_sl2_end)
if do_set: if method is tt.set_subtensor:
expected_result[:, sl3, :val_sl2_end] = val_inc expected_result[:, sl3, :val_sl2_end] = val_inc
else: else:
expected_result[:, sl3, :val_sl2_end] += val_inc expected_result[:, sl3, :val_sl2_end] += val_inc
utt.assert_allclose(result, expected_result) utt.assert_allclose(result, expected_result)
# Test when we broadcast the result
resut = method(a[sl1, sl2], increment)
f = theano.function([a, increment, sl2_end], resut)
expected_result = numpy.copy(val_a)
result = f(val_a, val_inc, val_sl2_end)
if method is tt.set_subtensor:
expected_result[:, :val_sl2_end] = val_inc
else:
expected_result[:, :val_sl2_end] += val_inc
utt.assert_allclose(result, expected_result)
def test_grad_inc_set(self): def test_grad_inc_set(self):
def inc_slice(*s): def inc_slice(*s):
def just_numeric_args(a, b): def just_numeric_args(a, b):
...@@ -138,19 +150,24 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -138,19 +150,24 @@ class Test_inc_subtensor(unittest.TestCase):
for f_slice in [inc_slice, set_slice]: for f_slice in [inc_slice, set_slice]:
# vector # vector
utt.verify_grad( utt.verify_grad(
f_slice(slice(2, 4, None)), f_slice(slice(2, 4, None)),
(numpy.asarray([0, 1, 2, 3, 4, 5.]), (numpy.asarray([0, 1, 2, 3, 4, 5.]),
numpy.asarray([9, 9.]), )) numpy.asarray([9, 9.]), ))
# matrix # matrix
utt.verify_grad( utt.verify_grad(
f_slice(slice(1, 2, None), slice(None, None, None)), f_slice(slice(1, 2, None), slice(None, None, None)),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]), (numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray([[9, 9.]]), )) numpy.asarray([[9, 9.]]), ))
#single element # single element
utt.verify_grad( utt.verify_grad(
f_slice(2, 1), f_slice(2, 1),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]), (numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray(9.),)) numpy.asarray(9.),))
# broadcast
utt.verify_grad(
f_slice(2),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray(9.),))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论