提交 f8ba9bb1 authored 作者: Thomas George's avatar Thomas George

added op lifter for slinalg.solve

上级 1494d16f
...@@ -31,6 +31,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv, ...@@ -31,6 +31,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
AbstractConv3d_gradWeights, AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs) AbstractConv3d_gradInputs)
import theano.tensor.signal.pool as pool import theano.tensor.signal.pool as pool
import theano.tensor.slinalg as slinalg
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
...@@ -68,6 +69,7 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor, ...@@ -68,6 +69,7 @@ from .subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedIncSubtensor1_dev20) GpuAdvancedIncSubtensor1_dev20)
from .opt_util import alpha_merge, output_merge, pad_dims, unpad_dims from .opt_util import alpha_merge, output_merge, pad_dims, unpad_dims
from .reduction import GpuMaxAndArgmax from .reduction import GpuMaxAndArgmax
from .linalg import GpuCusolverSolve
_logger = logging.getLogger("theano.gpuarray.opt") _logger = logging.getLogger("theano.gpuarray.opt")
...@@ -1884,6 +1886,14 @@ def _scan_type_infer(node): ...@@ -1884,6 +1886,14 @@ def _scan_type_infer(node):
def local_gpu_maxandargmax(op, context_name, inputs, outputs): def local_gpu_maxandargmax(op, context_name, inputs, outputs):
return GpuMaxAndArgmax(op.get_params(None)) return GpuMaxAndArgmax(op.get_params(None))
# solve: lift slinalg.Solve to the cuSOLVER-backed GPU op.
@register_opt('fast_compile')
@op_lifter([slinalg.Solve])
@register_opt2([slinalg.Solve], 'fast_compile')
def local_gpu_solve(op, context_name, inputs, outputs):
    """Replace a CPU ``slinalg.Solve`` node with ``GpuCusolverSolve``.

    Parameters
    ----------
    op
        The ``slinalg.Solve`` op being considered for lifting.
    context_name
        Name of the target GPU context (unused here).
    inputs, outputs
        Inputs and outputs of the node that applies ``op``.

    Returns
    -------
    GpuCusolverSolve or None
        The GPU op to use, or ``None`` to decline the lift.
    """
    # NOTE(review): GpuCusolverSolve is assumed to be single-precision only;
    # confirm against .linalg.  Declining the lift for any other dtype keeps
    # such graphs on the CPU op instead of failing when the GPU node is built.
    if any(inp.dtype != 'float32' for inp in inputs):
        return None
    return GpuCusolverSolve()
# Do not register in fast_run or fast_compile. # Do not register in fast_run or fast_compile.
# It will be added to fast_run if the GPU is enabled. # It will be added to fast_run if the GPU is enabled.
optdb.register('gpua_scanOp_make_inplace', optdb.register('gpua_scanOp_make_inplace',
......
...@@ -4,6 +4,7 @@ from nose.tools import assert_raises ...@@ -4,6 +4,7 @@ from nose.tools import assert_raises
import theano import theano
from theano import tensor from theano import tensor
import theano.tensor.slinalg as slinalg
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
from theano.tests import unittest_tools as utt, test_ifelse from theano.tests import unittest_tools as utt, test_ifelse
from theano.tensor.tests import test_basic from theano.tensor.tests import test_basic
...@@ -16,6 +17,7 @@ from ..basic_ops import ( ...@@ -16,6 +17,7 @@ from ..basic_ops import (
from ..blas import GpuGemm from ..blas import GpuGemm
from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise from ..elemwise import GpuCAReduceCuda, GpuCAReduceCPY, GpuElemwise
from ..subtensor import GpuSubtensor from ..subtensor import GpuSubtensor
from ..linalg import GpuCusolverSolve
from .config import mode_with_gpu, test_ctx_name from .config import mode_with_gpu, test_ctx_name
...@@ -496,3 +498,18 @@ def test_no_complex(): ...@@ -496,3 +498,18 @@ def test_no_complex():
stft_out = tensor.exp(width_var * freq_var) * signal_var stft_out = tensor.exp(width_var * freq_var) * signal_var
theano.function([width_var, freq_var, signal_var], stft_out, theano.function([width_var, freq_var, signal_var], stft_out,
mode=mode_with_gpu) mode=mode_with_gpu)
def test_local_lift_solve():
    """Check that ``slinalg.solve`` is lifted to ``GpuCusolverSolve``.

    Compiles the same graph with the default mode and with the GPU mode,
    verifies that the GPU graph holds the cuSOLVER op and no ``Solve``
    node, then compares the numerical results of both functions.
    """
    lhs = tensor.fmatrix()
    rhs = tensor.fmatrix()
    solution = slinalg.solve(lhs, rhs)

    f_cpu = theano.function([lhs, rhs], solution)
    f_gpu = theano.function([lhs, rhs], solution, mode=mode_with_gpu)

    # Collect the ops of the optimized GPU graph once; both checks use them.
    gpu_ops = [node.op for node in f_gpu.maker.fgraph.apply_nodes]
    assert not any(isinstance(gpu_op, slinalg.Solve) for gpu_op in gpu_ops)
    assert any(isinstance(gpu_op, GpuCusolverSolve) for gpu_op in gpu_ops)

    # Draw the matrix first and the right-hand side second.
    lhs_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
    rhs_val = numpy.random.uniform(-0.4, 0.4, (5, 3)).astype("float32")

    expected = f_cpu(lhs_val, rhs_val)
    actual = f_gpu(lhs_val, rhs_val)
    utt.assert_allclose(expected, actual)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论