提交 e971d93b，作者：Arnaud Bergeron

Use op_lifter for Gemm16

上级 e3f69de9
......@@ -8,8 +8,7 @@ from theano.gof import local_optimizer, COp
from theano.scalar import as_scalar, constant
from . import opt
from .basic_ops import (as_gpuarray_variable, gpu_from_host,
host_from_gpu, GpuAllocEmpty)
from .basic_ops import (as_gpuarray_variable, GpuAllocEmpty)
from .opt_util import alpha_merge, output_merge
from .pycuda_helper import ensure_pycuda_context
......@@ -161,18 +160,16 @@ if (GpuKernel_init(&k_%(name)s, c->ops, c->ctx, 1, &bcode, &sz,
@opt.register_opt()
@local_optimizer([tensor.Dot])
@opt.op_lifter([tensor.Dot])
def local_dot_to_gemm16(node):
if (type(node.op) == tensor.Dot and
node.inputs[0].dtype == 'float16' and
node.inputs[1].dtype == 'float16' and
node.inputs[0].ndim == 2 and node.inputs[1].ndim == 2):
A = node.inputs[0]
B = node.inputs[1]
if (A.ndim == 2 and B.ndim == 2 and
A.dtype == 'float16' and B.dtype == 'float16'):
fgraph = node.inputs[0].fgraph
A = gpu_from_host(node.inputs[0])
B = gpu_from_host(node.inputs[1])
C = GpuAllocEmpty(dtype='float16')(
shape_i(A, 0, fgraph), shape_i(B, 1, fgraph))
return [host_from_gpu(Gemm16()(C, 1.0, A, B, 0.0))]
return Gemm16()(C, 1.0, A, B, 0.0)
@opt.register_opt()
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论