提交 8451a929 authored 作者: James Bergstra's avatar James Bergstra

Adding scipy version of Ger

It's a little faster than the numpy implementation, but not much.
上级 e105f562
...@@ -6,6 +6,7 @@ from basic import * ...@@ -6,6 +6,7 @@ from basic import *
import opt import opt
import opt_uncanonicalize import opt_uncanonicalize
import blas import blas
import blas_scipy
import xlogx import xlogx
import raw_random, randomstreams import raw_random, randomstreams
......
...@@ -238,9 +238,9 @@ class Ger(Op): ...@@ -238,9 +238,9 @@ class Ger(Op):
def __str__(self): def __str__(self):
if self.destructive: if self.destructive:
return 'Ger{destructive}' return '%s{destructive}' % self.__class__.__name__
else: else:
return 'Ger{non-destructive}' return '%s{non-destructive}' % self.__class__.__name__
def make_node(self, A, alpha, x, y): def make_node(self, A, alpha, x, y):
A = T.as_tensor_variable(A) A = T.as_tensor_variable(A)
...@@ -1418,11 +1418,12 @@ blas_optdb.register('local_gemm_to_gemv', ...@@ -1418,11 +1418,12 @@ blas_optdb.register('local_gemm_to_gemv',
# After destroyhandler is in but before we try to make elemwise things inplace # After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace # Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71) # Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
optdb.register('InplaceBlasOpt', blas_opt_inplace = EquilibriumOptimizer(
EquilibriumOptimizer(
[local_inplace_gemm, local_inplace_gemv, local_inplace_ger], [local_inplace_gemm, local_inplace_gemv, local_inplace_ger],
failure_callback=EquilibriumOptimizer.warn_inplace, failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5), max_use_ratio=5)
optdb.register('InplaceBlasOpt',
blas_opt_inplace,
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
class Dot22Scalar(GemmRelated): class Dot22Scalar(GemmRelated):
......
"""
Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
import numpy
from blas import Ger, ger, ger_destructive
from blas import Gemv
from blas import Gemm
from blas import blas_optdb, optdb,local_optimizer, EquilibriumOptimizer
try:
import scipy.linalg.blas
have_fblas = True
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
}
_blas_ger_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sger,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dger,
#numpy.dtype('complex64'):scipy.linalg.blas.fblas.cger,
#numpy.dtype('complex128'):scipy.linalg.blas.fblas.zger,
}
except ImportError, e:
have_fblas = False
class ScipyGer(Ger):
# keep everything else, but override the make_thunk
def make_thunk(self, node, storage_map, compute_map, no_recycling):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
# get vars for containers
cA, calpha, cx, cy = node_input_storage
cZ, = node_output_storage
local_ger = _blas_ger_fns[numpy.dtype(node.inputs[0].type.dtype)]
def rval():
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A = local_ger(calpha[0], cx[0], cy[0], a=cA[0],
overwrite_a=int(self.destructive))
cZ[0] = A
#TODO: If this is currently an unofficial part of the thunk API,
# then maybe it should be documented and made official?
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
return rval
@local_optimizer([ger, ger_destructive])
def use_scipy_ger(node):
if node.op == ger:
return [ScipyGer(False)(*node.inputs)]
@local_optimizer([ScipyGer(False)])
def make_ger_destructive(node):
if node.op == ScipyGer(False):
return [ScipyGer(True)(*node.inputs)]
use_scipy_blas = EquilibriumOptimizer(
[use_scipy_ger],
max_use_ratio=5)
make_scipy_blas_destructive = EquilibriumOptimizer(
[make_ger_destructive],
max_use_ratio=5)
if have_fblas:
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks, but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register('scipy_blas',
use_scipy_blas,
100, 'fast_run')
# this matches the InplaceBlasOpt defined in blas.py
optdb.register('make_scipy_blas_destructive',
make_scipy_blas_destructive,
70.0, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论