提交 2813465d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add optimization to remove single-element joins.

上级 8ed74869
...@@ -21,7 +21,7 @@ from theano.tensor.nnet.conv import ConvOp ...@@ -21,7 +21,7 @@ from theano.tensor.nnet.conv import ConvOp
from theano.sandbox.gpuarray.type import GpuArrayType from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.basic_ops import ( from theano.sandbox.gpuarray.basic_ops import (
host_from_gpu, gpu_from_host, HostFromGpu, host_from_gpu, gpu_from_host, HostFromGpu,
gpu_alloc, GpuAlloc, GpuReshape, GpuEye, gpu_join, gpu_alloc, GpuAlloc, GpuReshape, GpuEye, gpu_join, GpuJoin,
) )
from theano.sandbox.gpuarray.blas import gpu_dot22, GpuGemv, GpuGemm, GpuGer from theano.sandbox.gpuarray.blas import gpu_dot22, GpuGemv, GpuGemm, GpuGer
from theano.sandbox.gpuarray.conv import GpuConv from theano.sandbox.gpuarray.conv import GpuConv
...@@ -273,6 +273,14 @@ def local_gpua_join(node): ...@@ -273,6 +273,14 @@ def local_gpua_join(node):
return gpu_join return gpu_join
@register_opt()
@local_optimizer([GpuJoin])
def local_gpuajoin_1(node):
# join of a single element
if (isinstance(node.op, GpuJoin) and
len(node.inputs) == 2):
return [node.inputs[1]]
@register_opt() @register_opt()
@op_lifter([tensor.Split]) @op_lifter([tensor.Split])
def local_gpua_split(node): def local_gpua_split(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论