提交 f8de0e04 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make blas ops pass context correctly.

上级 5b6448e2
......@@ -4,11 +4,11 @@ from theano import Apply, config
from theano.compile import optdb
from theano.gof import local_optimizer, LocalOptGroup
from theano.tensor.basic import as_tensor_variable
from theano.tensor.blas import Dot22, Gemv, Gemm, Ger
from theano.tensor.opt import in2out
from .basic_ops import HideC, as_gpuarray_variable, GpuAllocEmpty
from .basic_ops import (HideC, as_gpuarray_variable, GpuAllocEmpty,
infer_context_name)
try:
import pygpu
......@@ -28,34 +28,14 @@ class BlasOp(HideC):
def c_init_code(self):
return ['import_pygpu__blas();']
def c_support_code(self):
return """
PyGpuArrayObject *gpublas_try_copy(PyGpuArrayObject *out,
PyGpuArrayObject *y) {
if (out &&
GpuArray_CHKFLAGS(&out->ga, GA_CARRAY) &&
theano_size_check(out, PyGpuArray_NDIM(y),
PyGpuArray_DIMS(y),
y->ga.typecode)) {
if (pygpu_move(out, y)) {
Py_XDECREF(out);
return NULL;
}
} else {
Py_XDECREF(out);
out = pygpu_copy(y, GA_ANY_ORDER);
}
return out;
}
"""
class GpuGemv(BlasOp, Gemv):
def make_node(self, y, alpha, A, x, beta):
ctx_name = infer_context_name(y, A, x)
Gemv.make_node(self, y, alpha, A, x, beta)
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
A = as_gpuarray_variable(A, ctx_name)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
assert A.dtype == x.dtype == y.dtype
return Apply(self, [y, alpha, A, x, beta], [y.type()])
......@@ -73,7 +53,7 @@ class GpuGemv(BlasOp, Gemv):
if self.inplace:
code = """
if (%(y)s->ga.strides[0] <= 0) {
%(out)s = gpublas_try_copy(%(out)s, %(y)s);
%(out)s = theano_try_copy(%(out)s, %(y)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -85,7 +65,7 @@ class GpuGemv(BlasOp, Gemv):
""" % vars
else:
code = """
%(out)s = gpublas_try_copy(%(out)s, %(y)s);
%(out)s = theano_try_copy(%(out)s, %(y)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -106,7 +86,7 @@ class GpuGemv(BlasOp, Gemv):
return code
def c_code_cache_version(self):
return (3,)
return (4,)
gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
......@@ -116,11 +96,11 @@ class GpuGemm(BlasOp, Gemm):
_f16_ok = True
def make_node(self, C, alpha, A, B, beta):
alpha = as_tensor_variable(alpha)
beta = as_tensor_variable(beta)
A = as_gpuarray_variable(A)
B = as_gpuarray_variable(B)
C = as_gpuarray_variable(C)
ctx_name = infer_context_name(C, A, B)
Gemm.make_node(self, C, alpha, A, B, beta)
A = as_gpuarray_variable(A, ctx_name)
B = as_gpuarray_variable(B, ctx_name)
C = as_gpuarray_variable(C, ctx_name)
assert A.dtype == B.dtype == C.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()])
......@@ -138,7 +118,7 @@ class GpuGemm(BlasOp, Gemm):
if self.inplace:
code = """
if (!GpuArray_ISONESEGMENT(&%(C)s->ga)) {
%(out)s = gpublas_try_copy(%(out)s, %(C)s);
%(out)s = theano_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -150,7 +130,7 @@ class GpuGemm(BlasOp, Gemm):
""" % vars
else:
code = """
%(out)s = gpublas_try_copy(%(out)s, %(C)s);
%(out)s = theano_try_copy(%(out)s, %(C)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -171,8 +151,7 @@ class GpuGemm(BlasOp, Gemm):
return code
def c_code_cache_version(self):
return (4,)
return (5,)
gpugemm_no_inplace = GpuGemm(inplace=False)
gpugemm_inplace = GpuGemm(inplace=True)
......@@ -180,10 +159,11 @@ gpugemm_inplace = GpuGemm(inplace=True)
class GpuGer(BlasOp, Ger):
def make_node(self, A, alpha, x, y):
ctx_name = infer_context_name(A, x, y)
Ger.make_node(self, A, alpha, x, y)
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
A = as_gpuarray_variable(A, ctx_name)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
assert A.dtype == x.dtype == y.dtype
return Apply(self, [A, alpha, x, y], [A.type()])
......@@ -201,7 +181,7 @@ class GpuGer(BlasOp, Ger):
if self.destructive:
code = """
if (!GpuArray_ISONESEGMENT(&%(A)s->ga)) {
%(out)s = gpublas_try_copy(%(out)s, %(A)s);
%(out)s = theano_try_copy(%(out)s, %(A)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -213,7 +193,7 @@ class GpuGer(BlasOp, Ger):
""" % vars
else:
code = """
%(out)s = gpublas_try_copy(%(out)s, %(A)s);
%(out)s = theano_try_copy(%(out)s, %(A)s);
if (%(out)s == NULL) {
%(fail)s
}
......@@ -231,7 +211,7 @@ class GpuGer(BlasOp, Ger):
return code
def c_code_cache_version(self):
return (2,)
return (3,)
gpuger_no_inplace = GpuGer(destructive=False)
......@@ -240,9 +220,10 @@ gpuger_inplace = GpuGer(destructive=True)
class GpuDot22(BlasOp, Dot22):
def make_node(self, x, y):
ctx_name = infer_context_name(x, y)
Dot22.make_node(self, x, y)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
assert x.dtype == y.dtype
return Apply(self, [x, y], [x.type()])
......@@ -268,7 +249,7 @@ class GpuDot22(BlasOp, Dot22):
dims[1] = PyGpuArray_DIMS(%(B)s)[1];
if (theano_prep_output(&%(out)s, 2, dims, %(typecode)s, GA_C_ORDER,
pygpu_default_context())) {
%(A)s->ctx)) {
%(fail)s
}
......@@ -287,7 +268,7 @@ class GpuDot22(BlasOp, Dot22):
return code
def c_code_cache_version(self):
return (3,)
return (4,)
gpu_dot22 = GpuDot22()
......@@ -295,7 +276,12 @@ gpu_dot22 = GpuDot22()
@local_optimizer([gpugemv_no_inplace], inplace=True)
def local_inplace_gpuagemv(node):
if node.op == gpugemv_no_inplace:
return [gpugemv_inplace(*node.inputs)]
inputs = list(node.inputs)
y = inputs[0]
if (y.owner and isinstance(y.owner.op, GpuAllocEmpty) and
len(y.clients) > 1):
inputs[0] = y.owner.op(*y.owner.inputs)
return [gpugemv_inplace(*inputs)]
@local_optimizer([gpugemm_no_inplace], inplace=True)
......@@ -312,7 +298,12 @@ def local_inplace_gpuagemm(node):
@local_optimizer([gpuger_no_inplace], inplace=True)
def local_inplace_gpuager(node):
if node.op == gpuger_no_inplace:
return [gpuger_inplace(*node.inputs)]
inputs = list(node.inputs)
A = inputs[0]
if (A.owner and isinstance(A.owner.op, GpuAllocEmpty) and
len(A.clients) > 1):
inputs[0] = A.owner.op(*A.owner.inputs)
return [gpuger_inplace(*inputs)]
gpuablas_opt_inplace = in2out(LocalOptGroup(local_inplace_gpuagemv,
local_inplace_gpuagemm,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论