提交 1df9dacd authored 作者: Razvan Pascanu's avatar Razvan Pascanu

removed explicit call to cuda.zeros

上级 8d095a6d
......@@ -995,14 +995,10 @@ class Scan(PureOp):
self.vector_outs[j] = True
dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype):
if self.gpu and dtype in ['float32']:
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
outs[j][0] = _cuda.zeros(shape)
else:
outs[j][0] = numpy.zeros(shape, dtype)
outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype):
outs[j][0] = node.outputs[j].type.value_zeros(shape)
elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0]
......@@ -1040,24 +1036,14 @@ class Scan(PureOp):
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray):
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
tmp = _cuda.zeros(shape)
else:
tmp = numpy.empty(shape)
tmp = node.outputs[idx].type.value_zeros(shape)
tmp[:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx] - pdx] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx] - pdx:] = tmp
del tmp
else:
shape = (store_steps[idx] - pdx,) + outs[idx][0].shape[1:]
if cuda.cuda_available and isinstance(outs[idx][0],
cuda.CudaNdarray):
_cuda = cuda.cuda_ndarray.cuda_ndarray.CudaNdarray
tmp = _cuda.zeros(shape)
else:
tmp = numpy.empty(shape)
tmp = node.outputs[idx].type.value_zeros(shape)
tmp[:] = outs[idx][0][pdx:]
outs[idx][0][store_steps[idx] - pdx:] = outs[idx][0][:pdx]
outs[idx][0][:store_steps[idx] - pdx] = tmp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论