提交 306c8bc7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add optimization to replace CAReduce by their GPU versions.

上级 4d530ed0
...@@ -12,7 +12,7 @@ from theano.sandbox.gpuarray.type import GpuArrayType ...@@ -12,7 +12,7 @@ from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.basic_ops import (host_from_gpu, gpu_from_host, from theano.sandbox.gpuarray.basic_ops import (host_from_gpu, gpu_from_host,
gpu_alloc) gpu_alloc)
from theano.sandbox.gpuarray.elemwise import (GpuElemwise, _is_scalar, from theano.sandbox.gpuarray.elemwise import (GpuElemwise, _is_scalar,
GpuDimShuffle) GpuDimShuffle, GpuCAReduce)
from theano.sandbox.gpuarray.subtensor import GpuSubtensor from theano.sandbox.gpuarray.subtensor import GpuSubtensor
gpu_optimizer = EquilibriumDB() gpu_optimizer = EquilibriumDB()
...@@ -170,3 +170,10 @@ def local_gpua_specifyShape(node): ...@@ -170,3 +170,10 @@ def local_gpua_specifyShape(node):
@op_lifter(tensor.Subtensor) @op_lifter(tensor.Subtensor)
def local_gpua_subtensor(node): def local_gpua_subtensor(node):
return GpuSubtensor(node.op.idx_list) return GpuSubtensor(node.op.idx_list)
@register_opt()
@op_lifter(tensor.CAReduce)
def local_gpua_careduce(node):
return GpuCAReduce(node.op.scalar_op, axis=node.op.axis,
dtype=getattr(node.op, 'dtype', None),
acc_dtype=getattr(node.op, 'acc_dtype', None))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论