提交 3fce8613 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix blas C code having wrong alpha and beta for float16.

上级 f86883fc
......@@ -47,8 +47,8 @@ class GpuGemv(BlasOp):
A = as_gpuarray_variable(A, ctx_name)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64')
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
......@@ -128,8 +128,8 @@ class GpuGemm(BlasOp):
A = as_gpuarray_variable(A, ctx_name)
B = as_gpuarray_variable(B, ctx_name)
C = as_gpuarray_variable(C, ctx_name)
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64')
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
......@@ -208,7 +208,7 @@ class GpuGer(BlasOp):
A = as_gpuarray_variable(A, ctx_name)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha)
alpha = as_tensor_variable(alpha).astype('float64')
assert alpha.ndim == 0
assert A.ndim == 2
assert x.ndim == 1
......@@ -345,8 +345,8 @@ class GpuGemmBatch(BlasOp):
A = as_gpuarray_variable(A, ctx_name)
B = as_gpuarray_variable(B, ctx_name)
C = as_gpuarray_variable(C, ctx_name)
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
alpha = as_tensor_variable(alpha).astype('float64')
beta = as_tensor_variable(beta).astype('float64')
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论