提交 a6174c22 作者: Arnaud Bergeron

Enable float16 for GpuGemmBatch.

上级 c35ef4d8
......@@ -401,6 +401,7 @@ gpu_dot22 = GpuDot22()
class GpuGemmBatch(BlasOp):
__props__ = ('inplace',)
_f16_ok = True
def __init__(self, inplace=False):
self.inplace = inplace
......@@ -413,13 +414,21 @@ class GpuGemmBatch(BlasOp):
B = as_gpuarray_variable(B, ctx_name)
C = as_gpuarray_variable(C, ctx_name)
alpha = as_tensor_variable(alpha)
if alpha.dtype == 'float16':
alpha = alpha.astype('float32')
beta = as_tensor_variable(beta)
if beta.dtype == 'float16':
beta = beta.astype('float32')
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 3
assert B.ndim == 3
assert C.ndim == 3
assert A.dtype == B.dtype == C.dtype == alpha.dtype == beta.dtype
assert A.dtype == B.dtype == C.dtype
if A.dtype in ('float32', 'float64'):
assert A.dtype == alpha.dtype == beta.dtype
else:
assert 'float32' == alpha.dtype == beta.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()])
def c_headers(self):
......@@ -1726,9 +1735,14 @@ def local_inplace_gpuagemm(node, inputs):
def local_inplace_gpuager(node, inputs):
    """Graph rewrite: substitute the in-place variant of GpuGer.

    Returns a single-element list containing the node built by
    ``gpuger_inplace`` applied to the same inputs.
    """
    replacement = gpuger_inplace(*inputs)
    return [replacement]
@inplace_allocempty(GpuGemmBatch, 0)
def local_inplace_gpuagemmbatch(node, inputs):
    """Graph rewrite: substitute the in-place variant of GpuGemmBatch.

    Registered through ``inplace_allocempty`` so the substitution only
    triggers when output 0 may safely overwrite its input.
    """
    out = gpugemmbatch_inplace(*inputs)
    return [out]
# Bundle all GPU BLAS in-place substitutions into one in2out pass.
# NOTE(review): the stripped diff had fused the pre-change closing line
# ("local_inplace_gpuager),") with the post-change list, which made this
# call a syntax error; the post-change (added) version is kept here.
gpuablas_opt_inplace = in2out(LocalOptGroup(local_inplace_gpuagemv,
                                            local_inplace_gpuagemm,
                                            local_inplace_gpuager,
                                            local_inplace_gpuagemmbatch),
                              name='gpuablas_opt_inplace')
optdb.register('InplaceGpuaBlasOpt',
......
......@@ -1213,7 +1213,7 @@ def local_gpua_gemm(op, context_name, inputs, outputs):
@op_lifter([tensor.blas.BatchedDot])
@register_opt2([tensor.blas.BatchedDot], 'fast_compile')
def local_gpua_gemmbatch(op, context_name, inputs, outputs):
if inputs[0].dtype not in ['float32', 'float64']:
if inputs[0].dtype not in ['float16', 'float32', 'float64']:
return
a, b = inputs
# Since GpuGemmBatch only supports 3D inputs and output,
......
......@@ -16,7 +16,7 @@ from .config import mode_with_gpu, test_ctx_name
from .test_basic_ops import makeTester, rand
from ..blas import (gpugemv_inplace, gpugemv_no_inplace,
gpugemm_inplace, gpugemm_no_inplace,
gpugemmbatch_no_inplace,
gpugemmbatch_inplace,
gpuger_inplace, gpuger_no_inplace,
GpuGer, GpuGemm, gpu_dot22)
......@@ -130,7 +130,12 @@ gemm_batched_tests = dict(
("test_b%im%ik%in%i" % (b, m, k, n),
[rand(b, m, n), rand(), rand(b, m, k), rand(b, k, n), rand()])
for b, m, k, n in itertools.combinations([2, 3, 5, 7, 11, 13], 4))
# float16 case: the matrix operands z/x/y and the alpha/beta scalars are
# all generated as float16 here; per the GpuGemmBatch.make_node change in
# this same commit, float16 alpha/beta get upcast to float32 by the op.
# (The stale pre-change comment "# float16 not supported" was removed —
# it contradicted the case being added.)
gemm_batched_tests['float16'] = [rand(3, 4, 7).astype('float16'),
                                 rand().astype('float16'),
                                 rand(3, 4, 4).astype('float16'),
                                 rand(3, 4, 7).astype('float16'),
                                 rand().astype('float16')]
gemm_batched_tests['float32'] = [rand(3, 4, 7).astype('float32'),
rand().astype('float32'),
rand(3, 4, 4).astype('float32'),
......@@ -146,7 +151,7 @@ gemm_batched_tests['float64'] = [rand(3, 4, 7).astype('float64'),
# Tester comparing the reference expression alpha*batched_dot(x, y)+beta*z
# against the GPU batched-gemm op over the cases defined above.
# NOTE(review): the stripped diff fused two gpu_op= keyword lines
# (gpugemmbatch_no_inplace and gpugemmbatch_inplace), which is a
# duplicate-keyword SyntaxError; the added (_inplace) line is kept per
# removed-then-added diff order — confirm against the upstream commit.
GpuGemmBatchTester = makeTester(
    'GpuGemmBatchTester',
    op=lambda z, alpha, x, y, beta: alpha * batched_dot(x, y) + beta * z,
    gpu_op=gpugemmbatch_inplace,
    cases=gemm_batched_tests
)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论