Commit e971d93b authored by Arnaud Bergeron

Use op_lifter for Gemm16

Parent e3f69de9
@@ -8,8 +8,7 @@ from theano.gof import local_optimizer, COp
 from theano.scalar import as_scalar, constant
 from . import opt
-from .basic_ops import (as_gpuarray_variable, gpu_from_host,
-                        host_from_gpu, GpuAllocEmpty)
+from .basic_ops import (as_gpuarray_variable, GpuAllocEmpty)
 from .opt_util import alpha_merge, output_merge
 from .pycuda_helper import ensure_pycuda_context
@@ -161,18 +160,16 @@ if (GpuKernel_init(&k_%(name)s, c->ops, c->ctx, 1, &bcode, &sz,
 @opt.register_opt()
-@local_optimizer([tensor.Dot])
+@opt.op_lifter([tensor.Dot])
 def local_dot_to_gemm16(node):
-    if (type(node.op) == tensor.Dot and
-            node.inputs[0].dtype == 'float16' and
-            node.inputs[1].dtype == 'float16' and
-            node.inputs[0].ndim == 2 and node.inputs[1].ndim == 2):
+    A = node.inputs[0]
+    B = node.inputs[1]
+    if (A.ndim == 2 and B.ndim == 2 and
+            A.dtype == 'float16' and B.dtype == 'float16'):
         fgraph = node.inputs[0].fgraph
-        A = gpu_from_host(node.inputs[0])
-        B = gpu_from_host(node.inputs[1])
         C = GpuAllocEmpty(dtype='float16')(
             shape_i(A, 0, fgraph), shape_i(B, 1, fgraph))
-        return [host_from_gpu(Gemm16()(C, 1.0, A, B, 0.0))]
+        return Gemm16()(C, 1.0, A, B, 0.0)
 @opt.register_opt()
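The net effect of the change is that local_dot_to_gemm16 no longer inserts the gpu_from_host/host_from_gpu transfers itself: the op_lifter decorator matches the CPU Dot node and handles moving the result back to the host, so the rewrite can simply return the GPU Gemm16 output. As a rough illustration of the pattern (a simplified sketch, not Theano's actual op_lifter implementation; module paths reflect the sandbox gpuarray backend of this era):

    from theano.gof import local_optimizer
    from theano.sandbox.gpuarray.basic_ops import host_from_gpu

    def op_lifter_sketch(op_types):
        # Simplified sketch: wraps a rewrite function so it can return a GPU
        # variable directly; the wrapper transfers the result back to the host
        # so the replacement has the same type as the node it replaces.
        def decorator(maker):
            def local_opt(node):
                if type(node.op) in op_types:
                    new_out = maker(node)
                    if new_out is not None:
                        return [host_from_gpu(new_out)]
                return False
            local_opt.__name__ = maker.__name__
            return local_optimizer(op_types)(local_opt)
        return decorator

The real decorator additionally decides whether lifting is worthwhile (for example, by looking at whether the node's inputs or clients already live on the GPU), which this sketch omits.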