提交 98da3fcc 作者: Arnaud Bergeron

Use inplace_allocempty for blas inplace optimizers.

上级 4788a5a7
...@@ -3,12 +3,13 @@ import os.path ...@@ -3,12 +3,13 @@ import os.path
from theano import Apply, config from theano import Apply, config
from theano.compile import optdb from theano.compile import optdb
from theano.gof import local_optimizer, LocalOptGroup from theano.gof import LocalOptGroup
from theano.tensor.basic import as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from .basic_ops import (HideC, as_gpuarray_variable, GpuAllocEmpty, from .basic_ops import HideC, as_gpuarray_variable, infer_context_name
infer_context_name)
from .opt_util import inplace_allocempty
try: try:
import pygpu import pygpu
...@@ -295,37 +296,19 @@ class GpuDot22(BlasOp): ...@@ -295,37 +296,19 @@ class GpuDot22(BlasOp):
gpu_dot22 = GpuDot22() gpu_dot22 = GpuDot22()
@local_optimizer([gpugemv_no_inplace], inplace=True) @inplace_allocempty(GpuGemv, 0)
def local_inplace_gpuagemv(node): def local_inplace_gpuagemv(node, inputs):
if node.op == gpugemv_no_inplace: return [gpugemv_inplace(*inputs)]
inputs = list(node.inputs)
y = inputs[0]
if (y.owner and isinstance(y.owner.op, GpuAllocEmpty) and @inplace_allocempty(GpuGemm, 0)
len(y.clients) > 1): def local_inplace_gpuagemm(node, inputs):
inputs[0] = y.owner.op(*y.owner.inputs) return [gpugemm_inplace(*inputs)]
return [gpugemv_inplace(*inputs)]
@inplace_allocempty(GpuGer, 0)
@local_optimizer([gpugemm_no_inplace], inplace=True) def local_inplace_gpuager(node, inputs):
def local_inplace_gpuagemm(node): return [gpuger_inplace(*inputs)]
if node.op == gpugemm_no_inplace:
inputs = list(node.inputs)
C = inputs[0]
if (C.owner and isinstance(C.owner.op, GpuAllocEmpty) and
len(C.clients) > 1):
inputs[0] = C.owner.op(*C.owner.inputs)
return [gpugemm_inplace(*inputs)]
@local_optimizer([gpuger_no_inplace], inplace=True)
def local_inplace_gpuager(node):
if node.op == gpuger_no_inplace:
inputs = list(node.inputs)
A = inputs[0]
if (A.owner and isinstance(A.owner.op, GpuAllocEmpty) and
len(A.clients) > 1):
inputs[0] = A.owner.op(*A.owner.inputs)
return [gpuger_inplace(*inputs)]
gpuablas_opt_inplace = in2out(LocalOptGroup(local_inplace_gpuagemv, gpuablas_opt_inplace = in2out(LocalOptGroup(local_inplace_gpuagemv,
local_inplace_gpuagemm, local_inplace_gpuagemm,
......
...@@ -294,7 +294,7 @@ def inplace_allocempty(op, idx): ...@@ -294,7 +294,7 @@ def inplace_allocempty(op, idx):
function can be as simple as: function can be as simple as:
def maker(node, inputs): def maker(node, inputs):
return node.op.__class__(inplace=True)(*inputs) return [node.op.__class__(inplace=True)(*inputs)]
Parameters Parameters
---------- ----------
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论