提交 4e0fcb7d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix NumPy dtype conversion issues in Scan.perform

上级 330f7af4
...@@ -1967,7 +1967,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1967,7 +1967,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
or output_storage[j][0].dtype != dtype or output_storage[j][0].dtype != dtype
): ):
output_storage[j][0] = np.empty( output_storage[j][0] = np.empty(
shape, dtype=node.outputs[j].type shape, dtype=node.outputs[j].type.dtype
) )
elif output_storage[j][0].shape[0] != store_steps[j]: elif output_storage[j][0].shape[0] != store_steps[j]:
output_storage[j][0] = output_storage[j][0][: store_steps[j]] output_storage[j][0] = output_storage[j][0][: store_steps[j]]
...@@ -2025,7 +2025,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2025,7 +2025,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,) + output_storage[idx][0].shape[1:] shape = (pdx,) + output_storage[idx][0].shape[1:]
tmp = np.empty(shape, dtype=node.outputs[idx].type) tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
tmp[:] = output_storage[idx][0][:pdx] tmp[:] = output_storage[idx][0][:pdx]
output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[ output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[
idx idx
...@@ -2034,7 +2034,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2034,7 +2034,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
del tmp del tmp
else: else:
shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:] shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:]
tmp = np.empty(shape, dtype=node.outputs[idx].type) tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
tmp[:] = output_storage[idx][0][pdx:] tmp[:] = output_storage[idx][0][pdx:]
output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[ output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[
idx idx
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论