提交 12ff2b84 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

generate cuda memory only when dtype is float32

上级 ddf9cf04
...@@ -383,7 +383,7 @@ def perform( ...@@ -383,7 +383,7 @@ def perform(
outs[j][0].shape[0] < store_steps[j] or outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ): outs[j][0].dtype != dtype ):
if self.gpu: if self.gpu and dtype in ['float32']:
outs[j][0] = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape) outs[j][0] = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
else: else:
outs[j][0] = numpy.zeros(shape, dtype) outs[j][0] = numpy.zeros(shape, dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论