提交 03cc25ea authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6163 from botev/master

Small issues fixed.
...@@ -12,7 +12,6 @@ except ImportError: ...@@ -12,7 +12,6 @@ except ImportError:
from .basic_ops import (as_gpuarray_variable, GpuKernelBase, Kernel, from .basic_ops import (as_gpuarray_variable, GpuKernelBase, Kernel,
infer_context_name) infer_context_name)
from .opt import register_opt2, op_lifter, register_opt
from .type import GpuArrayType, gpu_context_type from .type import GpuArrayType, gpu_context_type
...@@ -578,11 +577,3 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -578,11 +577,3 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
def perform(self, node, inp, out, params): def perform(self, node, inp, out, params):
# Disable the perform method from the CPU version # Disable the perform method from the CPU version
Op.perform(self, node, inp, out, params) Op.perform(self, node, inp, out, params)
@register_opt('fast_compile')
@op_lifter([Images2Neibs])
@register_opt2([Images2Neibs], 'fast_compile')
def local_gpua_images2neibs(op, context_name, inputs, outputs):
if op.mode in ['valid', 'half', 'full', 'ignore_borders', 'wrap_centered']:
return GpuImages2Neibs(op.mode)
...@@ -33,6 +33,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv, ...@@ -33,6 +33,7 @@ from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
AbstractConv3d, AbstractConv3d,
AbstractConv3d_gradWeights, AbstractConv3d_gradWeights,
AbstractConv3d_gradInputs) AbstractConv3d_gradInputs)
from theano.tensor.nnet.neighbours import Images2Neibs
import theano.tensor.nlinalg as nlinalg import theano.tensor.nlinalg as nlinalg
import theano.tensor.signal.pool as pool import theano.tensor.signal.pool as pool
import theano.tensor.slinalg as slinalg import theano.tensor.slinalg as slinalg
...@@ -76,6 +77,7 @@ from .reduction import GpuMaxAndArgmax ...@@ -76,6 +77,7 @@ from .reduction import GpuMaxAndArgmax
from .linalg import (GpuCusolverSolve, MATRIX_STRUCTURES_SOLVE, GpuCholesky, from .linalg import (GpuCusolverSolve, MATRIX_STRUCTURES_SOLVE, GpuCholesky,
cusolver_available, GpuMagmaMatrixInverse, gpu_svd, cusolver_available, GpuMagmaMatrixInverse, gpu_svd,
GpuMagmaCholesky, gpu_qr, GpuMagmaEigh) GpuMagmaCholesky, gpu_qr, GpuMagmaEigh)
from .neighbours import GpuImages2Neibs
_logger = logging.getLogger("theano.gpuarray.opt") _logger = logging.getLogger("theano.gpuarray.opt")
...@@ -2086,6 +2088,14 @@ def local_gpu_maxandargmax(op, context_name, inputs, outputs): ...@@ -2086,6 +2088,14 @@ def local_gpu_maxandargmax(op, context_name, inputs, outputs):
return op return op
@register_opt('fast_compile')
@op_lifter([Images2Neibs])
@register_opt2([Images2Neibs], 'fast_compile')
def local_gpua_images2neibs(op, context_name, inputs, outputs):
if op.mode in ['valid', 'half', 'full', 'ignore_borders', 'wrap_centered']:
return GpuImages2Neibs(op.mode)
# solve # solve
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([slinalg.Solve]) @op_lifter([slinalg.Solve])
......
...@@ -325,6 +325,9 @@ solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True) ...@@ -325,6 +325,9 @@ solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True)
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is lower triangular.""" """Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is lower triangular."""
solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False) solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False)
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is upper triangular.""" """Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is upper triangular."""
# symmetric solves
solve_symmetric = Solve(A_structure='symmetric')
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is symmetric."""
# TODO: Optimizations to replace multiplication by matrix inverse # TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten) # with solve() Op (still unwritten)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论