提交 a6a7afb4 authored 作者: Melanie Ducoffe's avatar Melanie Ducoffe

optimization AllocEmpty -> GpuAllocEmpty, no tests

上级 b9ec8993
...@@ -26,7 +26,7 @@ from theano.sandbox.cuda.basic_ops import ( ...@@ -26,7 +26,7 @@ from theano.sandbox.cuda.basic_ops import (
GpuElemwise, GpuDimShuffle, GpuReshape, GpuCAReduce, GpuFlatten, GpuElemwise, GpuDimShuffle, GpuReshape, GpuCAReduce, GpuFlatten,
GpuSubtensor, GpuAdvancedSubtensor1, GpuSubtensor, GpuAdvancedSubtensor1,
GpuAdvancedIncSubtensor1, GpuAdvancedIncSubtensor1_dev20, GpuAdvancedIncSubtensor1, GpuAdvancedIncSubtensor1_dev20,
GpuIncSubtensor, gpu_alloc, GpuAlloc, gpu_shape, GpuSplit) GpuIncSubtensor, gpu_alloc, GpuAlloc, gpu_shape, GpuSplit, GpuAllocEmpty)
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar, from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar,
...@@ -2273,6 +2273,24 @@ def gpuScanOptimization(node): ...@@ -2273,6 +2273,24 @@ def gpuScanOptimization(node):
return outputs return outputs
return False return False
# Pending tests and fixes.
@register_opt()
@local_optimizer([tensor.AllocEmpty, gpu_from_host])
def local_gpu_allocempty(node):
    """Replace a float32 ``tensor.AllocEmpty`` by ``GpuAllocEmpty``.

    Two patterns are rewritten:

    * ``AllocEmpty(...)`` whose inputs include at least one
      ``HostFromGpu`` output becomes
      ``host_from_gpu(GpuAllocEmpty(...)(...))``.
    * ``GpuFromHost(AllocEmpty(...))`` becomes ``GpuAllocEmpty(...)(...)``.

    :param node: the Apply node currently visited by the optimizer.
    :return: a one-element list holding the replacement output, or
        ``False`` when neither pattern matches.
    """
    # AllocEmpty.__init__ stores its dtype as 'NPY_' + dtype.upper(), so a
    # "float32" op carries "NPY_FLOAT32" (no underscore before the 32).
    float32_tag = "NPY_FLOAT32"

    if (isinstance(node.op, tensor.AllocEmpty) and
            node.op.dtype == float32_tag):
        if any(i.owner and isinstance(i.owner.op, HostFromGpu)
               for i in node.inputs):
            # NOTE(review): gpu_from_host is applied to all inputs at once;
            # confirm it accepts more than one input.
            return [host_from_gpu(
                GpuAllocEmpty("float32")(gpu_from_host(*node.inputs)))]
    if isinstance(node.op, GpuFromHost):
        host_input = node.inputs[0]
        if (host_input.owner and
                isinstance(host_input.owner.op, tensor.AllocEmpty) and
                host_input.owner.op.dtype == float32_tag):
            owner = host_input.owner
            # NOTE(review): GpuAllocEmpty is built without a dtype here but
            # with "float32" in the branch above -- confirm which
            # constructor signature is the intended one.
            return [GpuAllocEmpty()(
                gpu_from_host(*owner.inputs))]
    # Neither pattern matched: leave the node untouched.
    return False
optdb.register('gpu_scanOp_make_inplace', optdb.register('gpu_scanOp_make_inplace',
scan_opt.ScanInplaceOptimizer(typeConstructor=typeConstructor, scan_opt.ScanInplaceOptimizer(typeConstructor=typeConstructor,
......
...@@ -5477,7 +5477,7 @@ class AllocEmpty(gof.Op): ...@@ -5477,7 +5477,7 @@ class AllocEmpty(gof.Op):
# specify the type of the data # specify the type of the data
def __init__(self, dtype): def __init__(self, dtype):
assert isinstance(dtype, string) assert isinstance(dtype, str)
self.dtype = 'NPY_' + dtype.upper() self.dtype = 'NPY_' + dtype.upper()
@staticmethod @staticmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论