提交 f8ba9bb1 authored 作者: Thomas George's avatar Thomas George

added op lifter for slinalg.solve

上级 1494d16f
......@@ -31,6 +31,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs)
import theano.tensor.signal.pool as pool
import theano.tensor.slinalg as slinalg
from theano.tests.breakpoint import PdbBreakpoint
......@@ -68,6 +69,7 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedIncSubtensor1_dev20)
from .opt_util import alpha_merge, output_merge, pad_dims, unpad_dims
from .reduction import GpuMaxAndArgmax
from .linalg import GpuCusolverSolve
_logger = logging.getLogger("theano.gpuarray.opt")
......@@ -1884,6 +1886,14 @@ def _scan_type_infer(node):
def local_gpu_maxandargmax(op, context_name, inputs, outputs):
return GpuMaxAndArgmax(op.get_params(None))
# solve
@register_opt('fast_compile')
@op_lifter([slinalg.Solve])
@register_opt2([slinalg.Solve], 'fast_compile')
def local_gpu_solve(op, context_name, inputs, outputs):
    """Lift a CPU ``slinalg.Solve`` apply node to ``GpuCusolverSolve``.

    Registered both as an op lifter and as a graph optimizer (via
    ``register_opt2``) under the ``fast_compile`` tag, mirroring the
    other lifters in this file (e.g. ``local_gpu_maxandargmax``).

    Parameters
    ----------
    op : slinalg.Solve
        The CPU solve op being replaced.
    context_name : str
        GPU context the lifted op should run in.
    inputs, outputs : lists of Variables
        The apply node's inputs/outputs (unused here; the lifter
        machinery rewires them).

    Returns
    -------
    GpuCusolverSolve
        The GPU replacement op.
    """
    # Use the module alias imported at the top of the file
    # (``import theano.tensor.slinalg as slinalg``) instead of the
    # fully qualified path, for consistency with the surrounding code.
    # NOTE(review): ``op.A_structure`` is discarded here — confirm that
    # GpuCusolverSolve's default handles (or this lifter should skip)
    # non-general matrix structures such as lower/upper triangular.
    return GpuCusolverSolve()
# Do not register in fast_run or fast_compile.
# It will be added to fast_run if the GPU is enabled.
optdb.register('gpua_scanOp_make_inplace',
......
......@@ -4,6 +4,7 @@ from nose.tools import assert_raises
import theano
from theano import tensor
import theano.tensor.slinalg as slinalg
from theano.tests.breakpoint import PdbBreakpoint
from theano.tests import unittest_tools as utt, test_ifelse
from theano.tensor.tests import test_basic
......@@ -16,6 +17,7 @@ from ..basic_ops import (
from ..blas import GpuGemm
from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise
from ..subtensor import GpuSubtensor
from ..linalg import GpuCusolverSolve
from .config import mode_with_gpu, test_ctx_name
......@@ -496,3 +498,18 @@ def test_no_complex():
stft_out = tensor.exp(width_var * freq_var) * signal_var
theano.function([width_var, freq_var, signal_var], stft_out,
mode=mode_with_gpu)
def test_local_lift_solve():
    """Check that solve(A, b) is lifted to GpuCusolverSolve under the GPU mode.

    Compiles the same graph for CPU and GPU, verifies the GPU graph no
    longer contains the CPU ``slinalg.Solve`` op but does contain
    ``GpuCusolverSolve``, and checks both functions agree numerically.
    """
    mat = tensor.fmatrix()
    rhs = tensor.fmatrix()
    sol = slinalg.solve(mat, rhs)
    f_cpu = theano.function([mat, rhs], sol)
    f_gpu = theano.function([mat, rhs], sol, mode=mode_with_gpu)
    # Collect the ops of the optimized GPU graph once, then assert the
    # CPU solve was removed and the cuSolver op was inserted.
    gpu_ops = [node.op for node in f_gpu.maker.fgraph.apply_nodes]
    assert not any(isinstance(an_op, slinalg.Solve) for an_op in gpu_ops)
    assert any(isinstance(an_op, GpuCusolverSolve) for an_op in gpu_ops)
    mat_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
    rhs_val = numpy.random.uniform(-0.4, 0.4, (5, 3)).astype("float32")
    utt.assert_allclose(f_cpu(mat_val, rhs_val), f_gpu(mat_val, rhs_val))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论