提交 e6777e0b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix elusive memory-corruption bug in Scan.

The problem was that when copying a part of an array onto an overlapping part of the same array, sometimes elements were overwritten before being copied, leading to silent data corruption.
上级 efade648
...@@ -2,9 +2,13 @@ ...@@ -2,9 +2,13 @@
Updates in the Trunk since the last release: Updates in the Trunk since the last release:
Sparse Sandbox bugfix Bug fixes
* Fix the grad of theano.sparse.sandbox.sp.row_scale. It didn't * Outputs of Scan nodes could contain corrupted values: some parts of the
returned the right number of element. (Frederic B.) output would be repeated a second time, instead of the correct values.
It happened randomly, and quite infrequently, but the bug has been present
(both in Python and Cython) since April 2011. (Pascal L.)
* In Sparse sandbox, fix the grad of theano.sparse.sandbox.sp.row_scale.
It did not return the right number of elements. (Frederic B.)
Documentation Documentation
* Added in the tutorial documentation on how to extend Theano. * Added in the tutorial documentation on how to extend Theano.
......
...@@ -809,7 +809,6 @@ class Scan(PureOp): ...@@ -809,7 +809,6 @@ class Scan(PureOp):
############## THE MAIN LOOP ######################### ############## THE MAIN LOOP #########################
#for i in xrange(n_steps): #for i in xrange(n_steps):
while (i < n_steps) and cond: while (i < n_steps) and cond:
# sequences over which scan iterates # sequences over which scan iterates
# 3. collect input slices # 3. collect input slices
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
...@@ -955,12 +954,17 @@ class Scan(PureOp): ...@@ -955,12 +954,17 @@ class Scan(PureOp):
begin = self.n_mit_mot begin = self.n_mit_mot
end = self.n_outs + self.n_nit_sot end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end): for idx in xrange(begin, end):
min_tap = self.mintaps[idx]
if (store_steps[idx] < i - self.mintaps[idx] and if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx]): pos[idx] < store_steps[idx]):
pdx = pos[idx] pdx = pos[idx]
if pdx < store_steps[idx] // 2: if pdx >= store_steps[idx] // 2:
# It seems inefficient to copy the bigger part of the
# array over, and back, but it is the only way that
# there is no overlap in the areas of out[idx][0] that
# are read and written.
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,) + outs[idx][0].shape[1:] shape = (pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance(outs[idx][0], if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray): cuda.CudaNdarray):
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -63,7 +63,7 @@ from theano.sandbox import cuda ...@@ -63,7 +63,7 @@ from theano.sandbox import cuda
def get_version(): def get_version():
return 0.265 return 0.266
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -418,7 +418,13 @@ def perform( ...@@ -418,7 +418,13 @@ def perform(
pos[idx] < store_steps[idx] ): pos[idx] < store_steps[idx] ):
pdx = pos[idx] pdx = pos[idx]
if pdx < store_steps[idx]//2 : if pdx >= store_steps[idx]//2 :
# It seems inefficient to copy the bigger part of the
# array over, and back, but it is the only way that
# there is no overlap in the areas of out[idx][0] that
# are read and written.
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,)+ outs[idx][0].shape[1:] shape = (pdx,)+ outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance( outs[idx][0], if cuda.cuda_available and isinstance( outs[idx][0],
......
...@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.DEBUG) ...@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.DEBUG)
if config.compiledir not in sys.path: if config.compiledir not in sys.path:
sys.path.append(config.compiledir) sys.path.append(config.compiledir)
version = 0.265 # must match constant returned in function get_version() version = 0.266 # must match constant returned in function get_version()
need_reload = False need_reload = False
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论