提交 8735401b authored 作者: Frederic Bastien's avatar Frederic Bastien

Use gpugemm for gemv in float16. fix gh-4341

上级 27a1361d
...@@ -1171,6 +1171,14 @@ def local_gpua_careduce(op, context_name, inputs, outputs): ...@@ -1171,6 +1171,14 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
@op_lifter([tensor.blas.Gemv, tensor.blas_c.CGemv]) @op_lifter([tensor.blas.Gemv, tensor.blas_c.CGemv])
@register_opt2([tensor.blas.Gemv], 'fast_compile') @register_opt2([tensor.blas.Gemv], 'fast_compile')
def local_gpua_gemv(op, context_name, inputs, outputs): def local_gpua_gemv(op, context_name, inputs, outputs):
if inputs[0].dtype == 'float16':
        # Use the gemm implementation, as cuBLAS gemv doesn't support float16
return gpugemm_no_inplace(inputs[0][:, None],
inputs[1],
inputs[2],
inputs[3][:, None],
inputs[4]).dimshuffle(0)
if inputs[0].dtype not in ['float32', 'float64']: if inputs[0].dtype not in ['float32', 'float64']:
return return
if op.inplace: if op.inplace:
......
...@@ -18,7 +18,7 @@ from ..blas import (gpugemv_inplace, gpugemv_no_inplace, ...@@ -18,7 +18,7 @@ from ..blas import (gpugemv_inplace, gpugemv_no_inplace,
gpugemm_inplace, gpugemm_no_inplace, gpugemm_inplace, gpugemm_no_inplace,
gpugemmbatch_no_inplace, gpugemmbatch_no_inplace,
gpuger_inplace, gpuger_no_inplace, gpuger_inplace, gpuger_no_inplace,
GpuGer, gpu_dot22) GpuGer, GpuGemm, gpu_dot22)
GpuGemvTester = makeTester( GpuGemvTester = makeTester(
...@@ -42,6 +42,22 @@ GpuGemvTester = makeTester( ...@@ -42,6 +42,22 @@ GpuGemvTester = makeTester(
def test_float16(): def test_float16():
# gemv (gemm called)
float16_data = [rand(3).astype('float16'),
np.asarray(1, dtype=np.float32),
rand(3, 3).astype('float16'),
rand(3).astype('float16'),
np.asarray(0.5, dtype=np.float32)]
float16_shared = [gpuarray_shared_constructor(val, target=test_ctx_name)
for val in float16_data]
o = gemm(*float16_shared)
f = theano.function([], o, mode=mode_with_gpu)
y, alpha, A, x, beta = float16_data
out = f()
utt.assert_allclose(np.asarray(out), alpha * np.dot(A, x) + beta * y)
topo = f.maker.fgraph.toposort()
assert any([isinstance(n.op, GpuGemm) for n in topo])
# gemm # gemm
float16_data = [rand(3, 3).astype('float16'), float16_data = [rand(3, 3).astype('float16'),
np.asarray(1, dtype=np.float32), np.asarray(1, dtype=np.float32),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论