提交 61d9631c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove the restriction on alpha and beta dtype in gpuarray blas ops.

上级 3cef522a
...@@ -28,7 +28,7 @@ class GpuGemv(BlasOp, Gemv): ...@@ -28,7 +28,7 @@ class GpuGemv(BlasOp, Gemv):
A = as_gpuarray_variable(A) A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x) x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y) y = as_gpuarray_variable(y)
assert A.dtype == x.dtype == y.dtype == alpha.dtype == beta.dtype assert A.dtype == x.dtype == y.dtype
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
...@@ -91,7 +91,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -91,7 +91,7 @@ class GpuGemm(BlasOp, Gemm):
A = as_gpuarray_variable(A) A = as_gpuarray_variable(A)
B = as_gpuarray_variable(B) B = as_gpuarray_variable(B)
C = as_gpuarray_variable(C) C = as_gpuarray_variable(C)
assert A.dtype == B.dtype == C.dtype == alpha.dtype == beta.dtype assert A.dtype == B.dtype == C.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()]) return Apply(self, [C, alpha, A, B, beta], [C.type()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -155,7 +155,7 @@ class GpuGer(BlasOp, Ger): ...@@ -155,7 +155,7 @@ class GpuGer(BlasOp, Ger):
A = as_gpuarray_variable(A) A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x) x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y) y = as_gpuarray_variable(y)
assert A.dtype == x.dtype == y.dtype == alpha.dtype assert A.dtype == x.dtype == y.dtype
return Apply(self, [A, alpha, x, y], [A.type()]) return Apply(self, [A, alpha, x, y], [A.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论