提交 a9564d15 authored 作者: Frederic Bastien's avatar Frederic Bastien

Upcast the alpha/beta scalar inputs when doing so allows the op to run on the GPU.

上级 147bcead
...@@ -59,10 +59,13 @@ class GpuGemv(BlasOp): ...@@ -59,10 +59,13 @@ class GpuGemv(BlasOp):
assert x.ndim == 1 assert x.ndim == 1
assert y.ndim == 1 assert y.ndim == 1
assert A.dtype == x.dtype == y.dtype assert A.dtype == x.dtype == y.dtype
if A.dtype == 'float16':
assert alpha.dtype == beta.dtype == 'float32' # float16 not supported
else: expected = A.dtype
assert alpha.dtype == beta.dtype == A.dtype assert theano.scalar.upcast(alpha.dtype,
beta.dtype, expected) == expected
alpha = alpha.astype(expected)
beta = beta.astype(expected)
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
...@@ -173,12 +176,18 @@ class GpuGemm(BlasOp): ...@@ -173,12 +176,18 @@ class GpuGemm(BlasOp):
raise TypeError(theano.tensor.blas.Gemm.E_mixed, raise TypeError(theano.tensor.blas.Gemm.E_mixed,
(A.dtype, B.dtype, C.dtype, (A.dtype, B.dtype, C.dtype,
alpha.dtype, beta.dtype)) alpha.dtype, beta.dtype))
if A.dtype == 'float16':
assert alpha.dtype == beta.dtype == 'float32'
else:
assert alpha.dtype == beta.dtype == A.dtype
if not A.dtype.startswith('float'): if not A.dtype.startswith('float'):
raise TypeError(theano.tensor.blas.Gemm.E_float, (A.dtype)) raise TypeError(theano.tensor.blas.Gemm.E_float, (A.dtype))
if A.dtype == 'float16':
expected = 'float32'
else:
expected = A.dtype
assert theano.scalar.upcast(alpha.dtype,
beta.dtype, expected) == expected
alpha = alpha.astype(expected)
beta = beta.astype(expected)
assert alpha.ndim == 0 assert alpha.ndim == 0
assert beta.ndim == 0 assert beta.ndim == 0
assert A.ndim == 2 assert A.ndim == 2
...@@ -257,10 +266,12 @@ class GpuGer(BlasOp): ...@@ -257,10 +266,12 @@ class GpuGer(BlasOp):
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name) y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha) alpha = as_tensor_variable(alpha)
if not(A.dtype == x.dtype == y.dtype == alpha.dtype): if not(A.dtype == x.dtype == y.dtype):
raise TypeError('ger requires matching dtypes', raise TypeError('ger requires matching dtypes',
(A.dtype, alpha.dtype, x.dtype, y.dtype)) (A.dtype, alpha.dtype, x.dtype, y.dtype))
assert theano.scalar.upcast(alpha.dtype, A.dtype) == A.dtype
alpha = alpha.astype(A.dtype)
assert alpha.ndim == 0 assert alpha.ndim == 0
assert A.ndim == 2 assert A.ndim == 2
assert x.ndim == 1 assert x.ndim == 1
......
...@@ -1175,6 +1175,8 @@ def local_gpua_gemv(op, context_name, inputs, outputs): ...@@ -1175,6 +1175,8 @@ def local_gpua_gemv(op, context_name, inputs, outputs):
@op_lifter([tensor.blas.Gemm]) @op_lifter([tensor.blas.Gemm])
@register_opt2([tensor.blas.Gemm], 'fast_compile') @register_opt2([tensor.blas.Gemm], 'fast_compile')
def local_gpua_gemm(op, context_name, inputs, outputs): def local_gpua_gemm(op, context_name, inputs, outputs):
if inputs[0].dtype not in ['float32', 'float64']:
return
if op.inplace: if op.inplace:
return gpugemm_inplace return gpugemm_inplace
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论