Commit 49d99209, authored by Frédéric Bastien, committed by GitHub

Merge pull request #6013 from abergeron/gemm3d

Enable float16 for GpuGemmBatch.
@@ -401,6 +401,7 @@ gpu_dot22 = GpuDot22()
 
 class GpuGemmBatch(BlasOp):
     __props__ = ('inplace',)
+    _f16_ok = True
 
     def __init__(self, inplace=False):
         self.inplace = inplace
@@ -413,13 +414,21 @@ class GpuGemmBatch(BlasOp):
         B = as_gpuarray_variable(B, ctx_name)
         C = as_gpuarray_variable(C, ctx_name)
         alpha = as_tensor_variable(alpha)
+        if alpha.dtype == 'float16':
+            alpha = alpha.astype('float32')
         beta = as_tensor_variable(beta)
+        if beta.dtype == 'float16':
+            beta = beta.astype('float32')
         assert alpha.ndim == 0
         assert beta.ndim == 0
         assert A.ndim == 3
         assert B.ndim == 3
         assert C.ndim == 3
-        assert A.dtype == B.dtype == C.dtype == alpha.dtype == beta.dtype
+        assert A.dtype == B.dtype == C.dtype
+        if A.dtype in ('float32', 'float64'):
+            assert A.dtype == alpha.dtype == beta.dtype
+        else:
+            assert 'float32' == alpha.dtype == beta.dtype
         return Apply(self, [C, alpha, A, B, beta], [C.type()])
 
     def c_headers(self):
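The new dtype rule deserves a gloss: `A`, `B` and `C` may now all be float16, but the scalar coefficients `alpha` and `beta` are upcast to (and must be) float32, presumably matching what the half-precision GPU BLAS entry points expect. A minimal NumPy sketch of the same invariant; the helper name is illustrative, not part of Theano:

```python
import numpy as np

def check_gemmbatch_dtypes(A, B, C, alpha, beta):
    # Mirrors GpuGemmBatch.make_node: float16 scalars are upcast first,
    # then the dtype invariant is asserted.
    alpha, beta = np.asarray(alpha), np.asarray(beta)
    if alpha.dtype == np.float16:
        alpha = alpha.astype(np.float32)
    if beta.dtype == np.float16:
        beta = beta.astype(np.float32)
    assert A.dtype == B.dtype == C.dtype
    if A.dtype in (np.float32, np.float64):
        # Full precision: scalars must match the operand dtype exactly.
        assert A.dtype == alpha.dtype == beta.dtype
    else:
        # float16 operands still use float32 scalar coefficients.
        assert alpha.dtype == beta.dtype == np.float32
    return alpha, beta

# float16 operands with float16 scalars now pass, thanks to the upcast.
A, B, C = (np.zeros(s, dtype=np.float16)
           for s in [(2, 3, 4), (2, 4, 5), (2, 3, 5)])
check_gemmbatch_dtypes(A, B, C, np.float16(1.0), np.float16(0.0))
```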
@@ -1726,9 +1735,15 @@ def local_inplace_gpuagemm(node, inputs):
 def local_inplace_gpuager(node, inputs):
     return [gpuger_inplace(*inputs)]
 
+
+@inplace_allocempty(GpuGemmBatch, 0)
+def local_inplace_gpuagemmbatch(node, inputs):
+    return [gpugemmbatch_inplace(*inputs)]
+
 gpuablas_opt_inplace = in2out(LocalOptGroup(local_inplace_gpuagemv,
                                             local_inplace_gpuagemm,
-                                            local_inplace_gpuager),
+                                            local_inplace_gpuager,
+                                            local_inplace_gpuagemmbatch),
                               name='gpuablas_opt_inplace')
 
 optdb.register('InplaceGpuaBlasOpt',
...
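For readers unfamiliar with the `inplace_allocempty` pattern: it registers a rewrite that substitutes `gpugemmbatch_inplace` only when input 0 (the destination buffer `C`) is produced directly by an empty allocation, so clobbering it is safe. A rough NumPy model of what the op computes, with a hypothetical `gemm_batch` helper that is not part of Theano:

```python
import numpy as np

def gemm_batch(C, alpha, A, B, beta, inplace=False):
    # Model of GpuGemmBatch: C <- alpha * (A @ B) + beta * C, one GEMM
    # per batch entry. The inplace flag decides whether C is clobbered.
    out = C if inplace else C.copy()
    for i in range(A.shape[0]):
        out[i] = alpha * np.dot(A[i], B[i]) + beta * out[i]
    return out

# Safe to run inplace here because C is a scratch buffer we just made;
# that is exactly the condition inplace_allocempty(GpuGemmBatch, 0)
# verifies on the graph before enabling the destructive version.
A = np.ones((2, 3, 4), dtype=np.float32)
B = np.ones((2, 4, 5), dtype=np.float32)
C = np.zeros((2, 3, 5), dtype=np.float32)
out = gemm_batch(C, 1.0, A, B, 0.0, inplace=True)
assert out is C
```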
@@ -1230,7 +1230,7 @@ def local_gpua_gemm(op, context_name, inputs, outputs):
 @op_lifter([tensor.blas.BatchedDot])
 @register_opt2([tensor.blas.BatchedDot], 'fast_compile')
 def local_gpua_gemmbatch(op, context_name, inputs, outputs):
-    if inputs[0].dtype not in ['float32', 'float64']:
+    if inputs[0].dtype not in ['float16', 'float32', 'float64']:
         return
     a, b = inputs
     # Since GpuGemmBatch only supports 3D inputs and output,
@@ -1252,7 +1252,8 @@ def local_gpua_gemmbatch(op, context_name, inputs, outputs):
     if b.dtype != out_dtype:
         b = gpu_cast_op(b)
-    c = tensor.AllocEmpty(out_dtype)(a.shape[0], a.shape[1], b.shape[2])
+    c = GpuAllocEmpty(out_dtype, context_name)(
+        a.shape[0], a.shape[1], b.shape[2])
     out = gpugemmbatch_no_inplace(c, np.asarray(1.0, dtype=out_dtype),
                                   a, b, np.asarray(0.0, dtype=out_dtype))
     if len(output_dims) != 3:
...
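With the lifter accepting float16 and the destination now allocated by `GpuAllocEmpty` on the correct context, a float16 `batched_dot` should end up on the GPU. A sketch of an end-to-end check, assuming a Theano install configured with a gpuarray device; the shapes and names are illustrative:

```python
import numpy as np
import theano
import theano.tensor as T

# Shared float16 inputs live on the GPU, so the BatchedDot node should
# be lifted to GpuGemmBatch by local_gpua_gemmbatch.
x = theano.shared(np.random.rand(8, 4, 5).astype('float16'))
y = theano.shared(np.random.rand(8, 5, 6).astype('float16'))
f = theano.function([], T.batched_dot(x, y))
print(f.maker.fgraph.toposort())  # expect a GpuGemmBatch node in here
```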
@@ -16,7 +16,7 @@ from .config import mode_with_gpu, test_ctx_name
 from .test_basic_ops import makeTester, rand
 from ..blas import (gpugemv_inplace, gpugemv_no_inplace,
                     gpugemm_inplace, gpugemm_no_inplace,
-                    gpugemmbatch_no_inplace,
+                    gpugemmbatch_inplace,
                     gpuger_inplace, gpuger_no_inplace,
                     GpuGer, GpuGemm, gpu_dot22)
@@ -130,7 +130,12 @@ gemm_batched_tests = dict(
     ("test_b%im%ik%in%i" % (b, m, k, n),
      [rand(b, m, n), rand(), rand(b, m, k), rand(b, k, n), rand()])
     for b, m, k, n in itertools.combinations([2, 3, 5, 7, 11, 13], 4))
-# float16 not supported
+gemm_batched_tests['float16'] = [rand(3, 4, 7).astype('float16'),
+                                 rand().astype('float16'),
+                                 rand(3, 4, 4).astype('float16'),
+                                 rand(3, 4, 7).astype('float16'),
+                                 rand().astype('float16')]
 gemm_batched_tests['float32'] = [rand(3, 4, 7).astype('float32'),
                                  rand().astype('float32'),
                                  rand(3, 4, 4).astype('float32'),
@@ -146,7 +151,7 @@ gemm_batched_tests['float64'] = [rand(3, 4, 7).astype('float64'),
 GpuGemmBatchTester = makeTester(
     'GpuGemmBatchTester',
     op=lambda z, alpha, x, y, beta: alpha * batched_dot(x, y) + beta * z,
-    gpu_op=gpugemmbatch_no_inplace,
+    gpu_op=gpugemmbatch_inplace,
     cases=gemm_batched_tests
 )
@@ -161,6 +166,7 @@ class TestGpuGemmBatchStrided(TestCase):
         x_num = np.arange(32 * 19 * 600, dtype=config.floatX).reshape((32, 19, 600))
         y_num = np.arange(7 * 32 * 600, dtype=config.floatX).reshape((32, 7, 600))
         f(x_num, y_num)
+        assert f.maker.fgraph.toposort()[-2].op.inplace
 
 
 class TestGpuSger(TestGer):
...
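The added assertion above inspects the optimized graph directly; index `-2` is used because the last node is typically the transfer back to host memory. A standalone sketch of the same check, assuming a GPU-enabled configuration:

```python
import numpy as np
import theano
import theano.tensor as T

# After gpuablas_opt_inplace runs, the graph should contain
# GpuGemmBatch{inplace=True} instead of the no-inplace variant.
x = T.tensor3('x')
y = T.tensor3('y')
f = theano.function([x, y], T.batched_dot(x, y))
x_num = np.arange(4 * 2 * 3, dtype=theano.config.floatX).reshape((4, 2, 3))
y_num = np.arange(4 * 3 * 5, dtype=theano.config.floatX).reshape((4, 3, 5))
f(x_num, y_num)
assert f.maker.fgraph.toposort()[-2].op.inplace
```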