提交 b1ceaa95 作者: Arnaud Bergeron

Use GpuAllocEmpty for Gemm16 output.

上级 5a9fcbce
...@@ -9,7 +9,7 @@ from theano.scalar import as_scalar, constant ...@@ -9,7 +9,7 @@ from theano.scalar import as_scalar, constant
from . import opt from . import opt
from .basic_ops import (as_gpuarray_variable, gpu_alloc, gpu_from_host, from .basic_ops import (as_gpuarray_variable, gpu_alloc, gpu_from_host,
host_from_gpu) host_from_gpu, GpuAllocEmpty)
from .opt_util import alpha_merge, output_merge from .opt_util import alpha_merge, output_merge
from .pycuda_helper import ensure_pycuda_context from .pycuda_helper import ensure_pycuda_context
...@@ -100,7 +100,7 @@ def local_dot_to_gemm16(node): ...@@ -100,7 +100,7 @@ def local_dot_to_gemm16(node):
fgraph = node.inputs[0].fgraph fgraph = node.inputs[0].fgraph
A = gpu_from_host(node.inputs[0]) A = gpu_from_host(node.inputs[0])
B = gpu_from_host(node.inputs[1]) B = gpu_from_host(node.inputs[1])
C = gpu_alloc(numpy.asarray(0, dtype='float16'), C = GpuAllocEmpty(dtype='float16')(
shape_i(A, 0, fgraph), shape_i(B, 1, fgraph)) shape_i(A, 0, fgraph), shape_i(B, 1, fgraph))
return [host_from_gpu(Gemm16()(C, 1.0, A, B, 0.0))] return [host_from_gpu(Gemm16()(C, 1.0, A, B, 0.0))]
...@@ -121,7 +121,13 @@ def local_gemm16_output_merge(node, *inputs): ...@@ -121,7 +121,13 @@ def local_gemm16_output_merge(node, *inputs):
def local_gemm16_inplace(node): def local_gemm16_inplace(node):
if type(node.op) != Gemm16 or node.op.inplace: if type(node.op) != Gemm16 or node.op.inplace:
return return
return [Gemm16(relu=node.op.relu, inplace=True)(*node.inputs)] inputs = list(node.inputs)
C = inputs[0]
if (C.owner and
isinstance(C.owner.op, GpuAllocEmpty) and
len(C.clients) > 1):
inputs[0] = C.owner.op(*C.owner.inputs)
return [Gemm16(relu=node.op.relu, inplace=True)(*inputs)]
optdb.register('local_gemm16_inplace', optdb.register('local_gemm16_inplace',
tensor.opt.in2out(local_gemm16_inplace, tensor.opt.in2out(local_gemm16_inplace,
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论