提交 4788a5a7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make explicit checks int blas ops.

上级 6c40675a
......@@ -4,7 +4,7 @@ from theano import Apply, config
from theano.compile import optdb
from theano.gof import local_optimizer, LocalOptGroup
from theano.tensor.blas import Dot22, Gemv, Gemm, Ger
from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import in2out
from .basic_ops import (HideC, as_gpuarray_variable, GpuAllocEmpty,
......@@ -29,13 +29,21 @@ class BlasOp(HideC):
return ['import_pygpu__blas();']
class GpuGemv(BlasOp, Gemv):
class GpuGemv(BlasOp):
def make_node(self, y, alpha, A, x, beta):
ctx_name = infer_context_name(y, A, x)
Gemv.make_node(self, y, alpha, A, x, beta)
A = as_gpuarray_variable(A, ctx_name)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
assert x.ndim == 1
assert y.ndim == 1
assert alpha.dtype in ['float32', 'float64']
assert beta.dtype in ['float32', 'float64']
assert A.dtype == x.dtype == y.dtype
return Apply(self, [y, alpha, A, x, beta], [y.type()])
......@@ -92,15 +100,23 @@ gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
class GpuGemm(BlasOp, Gemm):
class GpuGemm(BlasOp):
_f16_ok = True
def make_node(self, C, alpha, A, B, beta):
ctx_name = infer_context_name(C, A, B)
Gemm.make_node(self, C, alpha, A, B, beta)
A = as_gpuarray_variable(A, ctx_name)
B = as_gpuarray_variable(B, ctx_name)
C = as_gpuarray_variable(C, ctx_name)
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
assert alpha.ndim == 0
assert beta.ndim == 0
assert A.ndim == 2
assert B.ndim == 2
assert C.ndim == 2
assert alpha.dtype in ['float32', 'float64']
assert beta.dtype in ['float32', 'float64']
assert A.dtype == B.dtype == C.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()])
......@@ -157,13 +173,18 @@ gpugemm_no_inplace = GpuGemm(inplace=False)
gpugemm_inplace = GpuGemm(inplace=True)
class GpuGer(BlasOp, Ger):
class GpuGer(BlasOp):
def make_node(self, A, alpha, x, y):
ctx_name = infer_context_name(A, x, y)
Ger.make_node(self, A, alpha, x, y)
A = as_gpuarray_variable(A, ctx_name)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
alpha = as_tensor_variable(alpha)
assert alpha.ndim == 0
assert A.ndim == 2
assert x.ndim == 1
assert y.ndim == 1
assert alpha.dtype in ['float32', 'float64']
assert A.dtype == x.dtype == y.dtype
return Apply(self, [A, alpha, x, y], [A.type()])
......@@ -218,12 +239,13 @@ gpuger_no_inplace = GpuGer(destructive=False)
gpuger_inplace = GpuGer(destructive=True)
class GpuDot22(BlasOp, Dot22):
class GpuDot22(BlasOp):
def make_node(self, x, y):
ctx_name = infer_context_name(x, y)
Dot22.make_node(self, x, y)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
assert x.ndim == 2
assert y.ndim == 2
assert x.dtype == y.dtype
return Apply(self, [x, y], [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论