提交 fcf84d8d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

remove explicit call to cuda

上级 4e78cf82
...@@ -384,10 +384,7 @@ def perform( ...@@ -384,10 +384,7 @@ def perform(
outs[j][0].shape[0] < store_steps[j] or outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ): outs[j][0].dtype != dtype ):
if self.gpu and dtype in ['float32']: outs[j][0] = node.outputs[j].type.value_zeros(shape)
outs[j][0] = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
else:
outs[j][0] = numpy.zeros(shape, dtype)
elif outs[j][0].shape[0] != store_steps[j]: elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]] outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0] outs[j][0][pos[j]] = output_storage[jout].storage[0]
...@@ -427,22 +424,13 @@ def perform( ...@@ -427,22 +424,13 @@ def perform(
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,)+ outs[idx][0].shape[1:] shape = (pdx,)+ outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance( outs[idx][0], tmp = node.outputs[idx].type.value_zeros(shape)
cuda.CudaNdarray):
tmp = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
else:
tmp = numpy.empty(shape, outs[idx][0].dtype)
tmp[:] = outs[idx][0][:pdx] tmp[:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = outs[idx][0][pdx:] outs[idx][0][:store_steps[idx]-pdx] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = tmp outs[idx][0][store_steps[idx]-pdx:] = tmp
else: else:
shape = (store_steps[idx]-pdx,) + outs[idx][0].shape[1:] shape = (store_steps[idx]-pdx,) + outs[idx][0].shape[1:]
tmp = node.outputs[idx].type.value_zeros(shape)
if cuda.cuda_available and isinstance( outs[idx][0],
cuda.CudaNdarray):
tmp = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(shape)
else:
tmp = numpy.empty(shape, outs[idx][0].dtype)
tmp[:] = outs[idx][0][pdx:] tmp[:] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx]-pdx:] = outs[idx][0][:pdx] outs[idx][0][store_steps[idx]-pdx:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx]-pdx] = tmp outs[idx][0][:store_steps[idx]-pdx] = tmp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论