提交 a6678feb authored 作者: carriepl's avatar carriepl

Fix index error in scan backend

上级 ea96b166
......@@ -1323,7 +1323,7 @@ class Scan(PureOp):
if var is None:
old_mitmot_input_data[idx] = None
elif self.inps_is_tensor[idx]:
elif self.inps_is_tensor[idx + self.n_seqs]:
old_mitmot_input_data[idx] = var.data
else:
old_mitmot_input_data[idx] = var.gpudata
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -62,7 +62,7 @@ import copy
def get_version():
return 0.292
return 0.293
@cython.boundscheck(False)
def perform(
......@@ -385,7 +385,7 @@ def perform(
if var is None:
old_mitmot_input_data[idx] = None
elif inps_is_tensor[idx]:
elif inps_is_tensor[idx + n_seqs]:
old_mitmot_input_data[idx] = var.data
else:
old_mitmot_input_data[idx] = var.gpudata
......
......@@ -17,7 +17,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.292 # must match constant returned in function get_version()
version = 0.293 # must match constant returned in function get_version()
need_reload = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论