提交 90845417 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

allocate gpu memery only for float32 outputs

上级 80afdbeb
...@@ -972,7 +972,7 @@ class Scan(PureOp): ...@@ -972,7 +972,7 @@ class Scan(PureOp):
outs[j][0].shape[0] < store_steps[j] or outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype): outs[j][0].dtype != dtype):
if self.gpu: if self.gpu and dtype in ['float32']:
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray _cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
outs[j][0] = _cuda.zeros(shape) outs[j][0] = _cuda.zeros(shape)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论