提交 a6678feb authored 作者: carriepl's avatar carriepl

Fix index error in scan backend

上级 ea96b166
...@@ -1323,7 +1323,7 @@ class Scan(PureOp): ...@@ -1323,7 +1323,7 @@ class Scan(PureOp):
if var is None: if var is None:
old_mitmot_input_data[idx] = None old_mitmot_input_data[idx] = None
elif self.inps_is_tensor[idx]: elif self.inps_is_tensor[idx + self.n_seqs]:
old_mitmot_input_data[idx] = var.data old_mitmot_input_data[idx] = var.data
else: else:
old_mitmot_input_data[idx] = var.gpudata old_mitmot_input_data[idx] = var.gpudata
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ import copy ...@@ -62,7 +62,7 @@ import copy
def get_version(): def get_version():
return 0.292 return 0.293
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -385,7 +385,7 @@ def perform( ...@@ -385,7 +385,7 @@ def perform(
if var is None: if var is None:
old_mitmot_input_data[idx] = None old_mitmot_input_data[idx] = None
elif inps_is_tensor[idx]: elif inps_is_tensor[idx + n_seqs]:
old_mitmot_input_data[idx] = var.data old_mitmot_input_data[idx] = var.data
else: else:
old_mitmot_input_data[idx] = var.gpudata old_mitmot_input_data[idx] = var.gpudata
......
...@@ -17,7 +17,7 @@ from theano.gof import cmodule ...@@ -17,7 +17,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform') _logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.292 # must match constant returned in function get_version() version = 0.293 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论