提交 a3fc110c 作者: Tim Cooijmans

BatchedDot: fix optimization to handle matrices

上级 82aac7d9
......@@ -33,7 +33,7 @@ from theano.sandbox.cuda.basic_ops import (
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import (
gpu_dot22, gpu_dot22scalar, gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv,
GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights,
BatchedDotOp, GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights,
GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace
......@@ -619,22 +619,36 @@ def local_gpu_batched_dot(node):
"""
gpu_from_host(batched_dot) -> gpu_batched_dot(gpu_from_host)
batched_dot(host_from_gpu) -> host_from_gpu(batched_dot)
batched_dot(host_from_gpu) -> host_from_gpu(gpu_batched_dot)
"""
def gpu_batched_dot(x, y):
    """Apply ``BatchedDotOp`` to ``x`` and ``y``, accepting matrices too.

    An operand whose ``ndim`` is 2 is padded to a third-order tensor
    before the op is applied, and the axis introduced by that padding is
    removed from the result afterwards.  Operands with any other ``ndim``
    are passed through unchanged.

    Both operands are converted with ``as_cuda_ndarray_variable`` before
    the op is applied, and the final result is converted the same way.
    """
    x_is_matrix = x.ndim == 2
    y_is_matrix = y.ndim == 2

    # Pad: x gets a new axis in the middle, y gets a new trailing axis.
    lhs = x.dimshuffle(0, "x", 1) if x_is_matrix else x
    rhs = y.dimshuffle(0, 1, "x") if y_is_matrix else y

    batched_op = BatchedDotOp()
    z = batched_op(as_cuda_ndarray_variable(lhs),
                   as_cuda_ndarray_variable(rhs))

    # Unpad: undo the padding in the same order it was applied.
    if x_is_matrix:
        # Keep axis 0 and axes 2.., i.e. drop axis 1 of the result.
        kept_axes = [0] + list(range(2, z.ndim))
        z = z.dimshuffle(*kept_axes)
    if y_is_matrix:
        # Keep every axis except the last one.
        z = z.dimshuffle(*range(z.ndim - 1))
    return as_cuda_ndarray_variable(z)
"""
if isinstance(node.op, GpuFromHost):
host_input = node.inputs[0]
if host_input.owner and isinstance(host_input.owner.op,
tensor.BatchedDot):
x, y = host_input.owner.inputs
return [batched_dot(as_cuda_ndarray_variable(x),
as_cuda_ndarray_variable(y))]
return [gpu_batched_dot(x, y)]
if isinstance(node.op, tensor.BatchedDot):
if any([(i.owner and isinstance(i.owner.op, HostFromGpu))
for i in node.inputs]):
x, y = node.inputs
return [host_from_gpu(batched_dot(as_cuda_ndarray_variable(x),
as_cuda_ndarray_variable(y)))]
return [host_from_gpu(gpu_batched_dot(x, y))]
return False
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论