提交 61d9631c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove the restriction on alpha and beta dtype in gpuarray blas ops.

上级 3cef522a
......@@ -28,7 +28,7 @@ class GpuGemv(BlasOp, Gemv):
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
assert A.dtype == x.dtype == y.dtype == alpha.dtype == beta.dtype
assert A.dtype == x.dtype == y.dtype
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
......@@ -91,7 +91,7 @@ class GpuGemm(BlasOp, Gemm):
A = as_gpuarray_variable(A)
B = as_gpuarray_variable(B)
C = as_gpuarray_variable(C)
assert A.dtype == B.dtype == C.dtype == alpha.dtype == beta.dtype
assert A.dtype == B.dtype == C.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()])
def perform(self, node, inputs, outputs):
......@@ -155,7 +155,7 @@ class GpuGer(BlasOp, Ger):
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
assert A.dtype == x.dtype == y.dtype == alpha.dtype
assert A.dtype == x.dtype == y.dtype
return Apply(self, [A, alpha, x, y], [A.type()])
def perform(self, node, inp, out):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论