提交 d18c322f authored 作者: Frederic's avatar Frederic

pep8

上级 927ac9a4
...@@ -39,12 +39,12 @@ Dot22Scalar is a GEMM where b=0 and Z is allocated every time. ...@@ -39,12 +39,12 @@ Dot22Scalar is a GEMM where b=0 and Z is allocated every time.
Gemm is a GEMM in all its generality. Gemm is a GEMM in all its generality.
In the future we can refactor the GemmRelated, Gemm, Dot22 and In the future we can refactor the GemmRelated, Gemm, Dot22 and
Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a normal Gemm, but Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a
with an additional configuration variable that says to ignore the input Z. normal Gemm, but with an additional configuration variable that says
Setting that configuration variable to True would make Gemm2 equivalent to the to ignore the input Z. Setting that configuration variable to True
current Dot22 and Dot22Scalar. This would make the file a lot easier to read, would make Gemm2 equivalent to the current Dot22 and Dot22Scalar.
and save a few hundred lines of library, to say nothing of testing and This would make the file a lot easier to read, and save a few hundred
documentation. lines of library, to say nothing of testing and documentation.
GEMV: Gemv GEMV: Gemv
...@@ -109,21 +109,24 @@ This is complicated, done in GemmOptimizer. ...@@ -109,21 +109,24 @@ This is complicated, done in GemmOptimizer.
Identify Dot22Scalar from Dot22 Identify Dot22Scalar from Dot22
------------------------------- -------------------------------
Dot22 Ops that remain after the GemmOptimizer is done have not qualified as GEMM Dot22 Ops that remain after the GemmOptimizer is done have not
Ops. Still they might be scaled by a factor, in which case we use Dot22Scalar qualified as GEMM Ops. Still they might be scaled by a factor, in
which is like Gemm, but without the b and the Z. In the future it would be good which case we use Dot22Scalar which is like Gemm, but without the b
to merge this into the GemmOptimizer. and the Z. In the future it would be good to merge this into the
GemmOptimizer.
Specialize Gemm to Gemv Specialize Gemm to Gemv
----------------------- -----------------------
If arguments to GEMM are dimshuffled vectors, then we can use GEMV instead. This If arguments to GEMM are dimshuffled vectors, then we can use GEMV
optimization is `local_gemm_to_gemv`. instead. This optimization is `local_gemm_to_gemv`.
""" """
import copy
import logging, copy, os, sys import logging
import os
import sys
import numpy import numpy
import numpy.distutils import numpy.distutils
...@@ -137,7 +140,7 @@ from theano.compile.mode import optdb ...@@ -137,7 +140,7 @@ from theano.compile.mode import optdb
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
import theano.scalar import theano.scalar
import basic as T import basic as T
from theano.tensor.blas_headers import blas_header_text #, cblas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.opt import local_dimshuffle_lift from theano.tensor.opt import local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
...@@ -146,16 +149,17 @@ try: ...@@ -146,16 +149,17 @@ try:
import scipy.linalg.blas import scipy.linalg.blas
_have_fblas = True _have_fblas = True
_blas_gemv_fns = { _blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv, numpy.dtype('float32'): scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv, numpy.dtype('float64'): scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv, numpy.dtype('complex64'): scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv, numpy.dtype('complex128'): scipy.linalg.blas.fblas.zgemv,
} }
except ImportError, e: except ImportError, e:
_have_fblas = False _have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas.fblas. ' _logger.warning('Failed to import scipy.linalg.blas.fblas. '
'Falling back on slower implementations (%s)', str(e)) 'Falling back on slower implementations (%s)', str(e))
class Gemv(Op): class Gemv(Op):
""" """
expression is beta * y + alpha * A x expression is beta * y + alpha * A x
...@@ -166,12 +170,12 @@ class Gemv(Op): ...@@ -166,12 +170,12 @@ class Gemv(Op):
output is a vector that can be inplace on y output is a vector that can be inplace on y
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.inplace=inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map={0:[0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self)==type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -189,35 +193,44 @@ class Gemv(Op): ...@@ -189,35 +193,44 @@ class Gemv(Op):
alpha = T.as_tensor_variable(alpha) alpha = T.as_tensor_variable(alpha)
beta = T.as_tensor_variable(beta) beta = T.as_tensor_variable(beta)
if y.dtype != A.dtype or y.dtype != x.dtype: if y.dtype != A.dtype or y.dtype != x.dtype:
raise TypeError('Gemv requires matching dtypes', (y.dtype, A.dtype, x.dtype)) raise TypeError('Gemv requires matching dtypes',
if A.ndim != 2: raise TypeError('gemv requires matrix for A', A.type) (y.dtype, A.dtype, x.dtype))
if x.ndim != 1: raise TypeError('gemv requires vector for x', x.type) if A.ndim != 2:
if y.ndim != 1: raise TypeError('gemv requires vector for y', y.type) raise TypeError('gemv requires matrix for A', A.type)
if x.ndim != 1:
raise TypeError('gemv requires vector for x', x.type)
if y.ndim != 1:
raise TypeError('gemv requires vector for y', y.type)
if y.broadcastable[0] != A.broadcastable[0]: if y.broadcastable[0] != A.broadcastable[0]:
raise TypeError('broadcastable mismatch between y and A', (y.type, A.type)) raise TypeError('broadcastable mismatch between y and A',
# The following is not grounds for error (y.type, A.type))
# because as long as sizes are 1 at time of perform() there is no problem # The following is not grounds for error because as long as
# sizes are 1 at time of perform() there is no problem
#if x.broadcastable[0] != A.broadcastable[1]: #if x.broadcastable[0] != A.broadcastable[1]:
#raise TypeError('broadcastable mismatch between x and A', (x.type, A.type)) #raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
if _have_fblas and y.shape[0]!=0 and x.shape[0]!=0: if _have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv ' raise ValueError('Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s ' '(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))# % (y.shape, A.shape, x.shape))
#Here I suppose that A is in c order. If we don't make it explicitly #Here I suppose that A is in c order. If we don't make it
# as fortran order, scipy 0.7.2 seam to create a copy in fortran # explicitly as fortran order, scipy 0.7.2 seam to create
# order instead of just reshaping it and using the trans flag. # a copy in fortran order instead of just reshaping it
# and using the trans flag.
#If A is already in fortran order, make it in c order and using the #If A is already in fortran order, make it in c order and using the
# trans flag don't seam to cause slowdown. # trans flag don't seam to cause slowdown.
#out_storage[0][0] = gemv(alpha, A, x, beta, y, overwrite_y=self.inplace) #out_storage[0][0] = gemv(alpha, A, x, beta, y,
out_storage[0][0] = gemv(alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True) # overwrite_y=self.inplace)
out_storage[0][0] = gemv(alpha, A.T, x, beta, y,
overwrite_y=self.inplace, trans=True)
else: else:
out = numpy.dot(A, x) out = numpy.dot(A, x)
if alpha != 1: if alpha != 1:
...@@ -231,6 +244,7 @@ class Gemv(Op): ...@@ -231,6 +244,7 @@ class Gemv(Op):
gemv_no_inplace = Gemv(inplace=False) gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True) gemv_inplace = Gemv(inplace=True)
class Ger(Op): class Ger(Op):
""" """
BLAS defines general rank-1 update GER as A <- A + alpha x y' BLAS defines general rank-1 update GER as A <- A + alpha x y'
...@@ -245,12 +259,13 @@ class Ger(Op): ...@@ -245,12 +259,13 @@ class Ger(Op):
and override the make_thunk() method to use Scipy and C respectively. and override the make_thunk() method to use Scipy and C respectively.
""" """
def __init__(self, destructive): def __init__(self, destructive):
self.destructive=destructive self.destructive = destructive
if destructive: if destructive:
self.destroy_map={0:[0]} self.destroy_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self)==type(other) and self.destructive == other.destructive return (type(self) == type(other) and
self.destructive == other.destructive)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.destructive) return hash(type(self)) ^ hash(self.destructive)
...@@ -299,30 +314,36 @@ class Ger(Op): ...@@ -299,30 +314,36 @@ class Ger(Op):
ger = Ger(destructive=False) ger = Ger(destructive=False)
ger_destructive = Ger(destructive=True) ger_destructive = Ger(destructive=True)
def default_blas_ldflags(): def default_blas_ldflags():
try: try:
# If we are in a EPD installation, mkl is available # If we are in a EPD installation, mkl is available
blas_info = numpy.distutils.__config__.blas_opt_info
if "EPD" in sys.version: if "EPD" in sys.version:
if sys.platform == 'win32': if sys.platform == 'win32':
return ' '.join( return ' '.join(
['-L%s' % os.path.join(sys.prefix, "Scripts")] + ['-L%s' % os.path.join(sys.prefix, "Scripts")] +
# Why on Windows, the library used are not the # Why on Windows, the library used are not the
# same as what is in # same as what is in
# numpy.distutils.__config__.blas_opt_info['libraries']? # blas_info['libraries']?
['-l%s' % l for l in ["mk2_core", "mk2_intel_thread", ['-l%s' % l for l in ["mk2_core", "mk2_intel_thread",
"mk2_rt"]]) "mk2_rt"]])
return ' '.join( return ' '.join(
['-L%s' % os.path.join(sys.prefix, "lib")] + ['-L%s' % os.path.join(sys.prefix, "lib")] +
['-l%s' % l for l in numpy.distutils.__config__.blas_opt_info['libraries']]) ['-l%s' % l for l in blas_info['libraries']])
#if numpy was linked with library that are not installed, we can't reuse them. #if numpy was linked with library that are not installed, we
if all(not os.path.exists(dir) for dir in numpy.distutils.__config__.blas_opt_info['library_dirs']): #can't reuse them.
if all(not os.path.exists(dir) for dir in blas_info['library_dirs']):
return "-lblas" return "-lblas"
return ' '.join( return ' '.join(
#TODO: the Gemm op below should separate the -L and -l arguments into the two callbacks that CLinker uses for that stuff. #TODO: the Gemm op below should separate the
# for now, we just pass the whole ldflags as the -l options part. # -L and -l arguments into the two callbacks
['-L%s'%l for l in numpy.distutils.__config__.blas_opt_info['library_dirs']] + # that CLinker uses for that stuff. for now,
['-l%s'%l for l in numpy.distutils.__config__.blas_opt_info['libraries']]) # we just pass the whole ldflags as the -l
# ['-I%s'%l for l in numpy.distutils.__config__.blas_opt_info['include_dirs']]) # options part.
['-L%s' % l for l in blas_info['library_dirs']] +
['-l%s' % l for l in blas_info['libraries']])
# ['-I%s' % l for l in blas_info['include_dirs']])
except KeyError: except KeyError:
return "-lblas" return "-lblas"
...@@ -330,23 +351,27 @@ AddConfigVar('blas.ldflags', ...@@ -330,23 +351,27 @@ AddConfigVar('blas.ldflags',
"lib[s] to include for [Fortran] level-3 blas implementation", "lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags())) StrParam(default_blas_ldflags()))
@utils.memoize @utils.memoize
def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
"""Return a list of libraries against which an Op's object file should be """Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation. linked to benefit from a BLAS implementation.
Default: ['blas'], but configuration variable config.blas.ldflags overrides this. Default: ['blas'], but configuration variable config.blas.ldflags
overrides this.
""" """
rval = [] rval = []
if libs_dir: if libs_dir:
found_dyn=False found_dyn = False
dirs = [x[2:] for x in config.blas.ldflags.split() if x.startswith('-L')] dirs = [x[2:] for x in config.blas.ldflags.split()
if x.startswith('-L')]
l = ldflags() l = ldflags()
for d in dirs: for d in dirs:
for f in os.listdir(d): for f in os.listdir(d):
if f.endswith('.so') or f.endswith('.dylib') or f.endswith('.dll'): if (f.endswith('.so') or f.endswith('.dylib') or
if any([f.find(ll)>=0 for ll in l]): f.endswith('.dll')):
found_dyn=True if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True
if not found_dyn and dirs: if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the " _logger.warning("We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use " "library_dir of the library we use for blas. If you use "
...@@ -361,17 +386,21 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): ...@@ -361,17 +386,21 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
if libs_dir and t1 == 'L': if libs_dir and t1 == 'L':
rval.append(t[2:]) rval.append(t[2:])
elif include_dir and t1 == 'I': elif include_dir and t1 == 'I':
raise ValueError('Include dirs are not used for blas. We disable this as this can hide other headers and this is not wanted.', t) raise ValueError('Include dirs are not used for blas. We disable'
' this as this can hide other headers and this'
' is not wanted.', t)
rval.append(t[2:]) rval.append(t[2:])
elif libs and t1=='l': # example -lmkl elif libs and t1 == 'l': # example -lmkl
rval.append(t[2:]) rval.append(t[2:])
elif flags and t1 not in ['L','I','l']: # example -openmp elif flags and t1 not in ['L', 'I', 'l']: # example -openmp
rval.append(t) rval.append(t)
elif flags and t1 == 'L': elif flags and t1 == 'L':
#to find it when we load the compiled op if the env of the used is not well configured. #to find it when we load the compiled op if the env of the
rval.append('-Wl,-rpath,'+t[2:]) #used is not well configured.
rval.append('-Wl,-rpath,' + t[2:])
return rval return rval
class GemmRelated(Op): class GemmRelated(Op):
"""Base class for Gemm and Dot22 """Base class for Gemm and Dot22
...@@ -379,10 +408,13 @@ class GemmRelated(Op): ...@@ -379,10 +408,13 @@ class GemmRelated(Op):
""" """
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_support_code(self): def c_support_code(self):
#return cblas_header_text() #return cblas_header_text()
mod_str = """ mod_str = """
...@@ -397,6 +429,7 @@ class GemmRelated(Op): ...@@ -397,6 +429,7 @@ class GemmRelated(Op):
} }
""" """
return blas_header_text() + mod_str return blas_header_text() + mod_str
def c_headers(self): def c_headers(self):
# std.cout doesn't require the '%' symbol to print stuff... # std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff. # so it works much better with python's string-substitution stuff.
...@@ -670,6 +703,7 @@ class GemmRelated(Op): ...@@ -670,6 +703,7 @@ class GemmRelated(Op):
def build_gemm_version(self): def build_gemm_version(self):
return (12,) return (12,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation): """In-place version of matrix-matrix multiplication (with accumulation):
...@@ -681,14 +715,15 @@ class Gemm(GemmRelated): ...@@ -681,14 +715,15 @@ class Gemm(GemmRelated):
b*z + a*dot(x,y) b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z, The difference between the two is that the top form is destructive
whereas the bottom form is not. Gemm works in-place on the storage on z, whereas the bottom form is not. Gemm works in-place on the
associated with z, and the L{Variable} returned by Gemm has a storage that storage associated with z, and the L{Variable} returned by Gemm
will be aliased to the storage of the z argument. Because of this in-place has a storage that will be aliased to the storage of the z
computation, an L{Apply} of this op will destroy the L{Variable} z on argument. Because of this in-place computation, an L{Apply} of
which it operates. (See L{DestructiveOps} for an explanation of what this op will destroy the L{Variable} z on which it operates. (See
destroying means in the context of theano graphs. See L{BlasLapackSupport} for L{DestructiveOps} for an explanation of what destroying means in
more optimized linear algebra operations.) the context of theano graphs. See L{BlasLapackSupport} for more
optimized linear algebra operations.)
""" """
E_rank = 'gemm only works for rank 2' E_rank = 'gemm only works for rank 2'
...@@ -698,18 +733,20 @@ class Gemm(GemmRelated): ...@@ -698,18 +733,20 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes' E_float = 'gemm requires floating-point dtypes'
def __init__(self, inplace): def __init__(self, inplace):
self.__setstate__({'inplace':inplace}) self.__setstate__({'inplace': inplace})
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)\ return (type(self) == type(other) and
and self.inplace == other.inplace) self.inplace == other.inplace)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.inplace) return hash(type(self)) ^ hash(self.inplace)
def __str__(self): def __str__(self):
if self.inplace: inplace_str = 'inplace' if self.inplace:
else: inplace_str = 'no_inplace' inplace_str = 'inplace'
else:
inplace_str = 'no_inplace'
return '%s{%s}' % (self.__class__.__name__, inplace_str) return '%s{%s}' % (self.__class__.__name__, inplace_str)
def __setstate__(self, dct): def __setstate__(self, dct):
...@@ -727,9 +764,11 @@ class Gemm(GemmRelated): ...@@ -727,9 +764,11 @@ class Gemm(GemmRelated):
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs) inputs = map(T.as_tensor_variable, inputs)
if len(inputs) != 5: if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs))) raise TypeError(
"Wrong number of inputs for %s (expected 5, got %s)" %
(self, len(inputs)))
z, a, x, y, b = inputs z, a, x, y, b = inputs
zr, xr, yr = [set(view_roots(i)) for i in z,x,y] zr, xr, yr = [set(view_roots(i)) for i in z, x, y]
# TODO: justify / delete # TODO: justify / delete
if zr.intersection(xr): if zr.intersection(xr):
...@@ -767,26 +806,26 @@ class Gemm(GemmRelated): ...@@ -767,26 +806,26 @@ class Gemm(GemmRelated):
if not self.inplace: if not self.inplace:
z = z.copy() # the original z will not be changed z = z.copy() # the original z will not be changed
if z.shape == (): if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y)) z.itemset(z * a + b * numpy.dot(x, y))
zout[0] = z zout[0] = z
else: else:
if b == 0.0: if b == 0.0:
if a == 1.0: if a == 1.0:
z[:] = numpy.dot(x,y) z[:] = numpy.dot(x, y)
elif a == -1.0: elif a == -1.0:
z[:] = -numpy.dot(x,y) z[:] = -numpy.dot(x, y)
else: else:
z[:] = a * numpy.dot(x,y) z[:] = a * numpy.dot(x, y)
elif b == 1.0: elif b == 1.0:
if a == 1.0: if a == 1.0:
z += numpy.dot(x,y) z += numpy.dot(x, y)
elif a == -1.0: elif a == -1.0:
z -= numpy.dot(x,y) z -= numpy.dot(x, y)
else: else:
z += a * numpy.dot(x,y) z += a * numpy.dot(x, y)
else: else:
z *= b z *= b
z += a * numpy.dot(x,y) z += a * numpy.dot(x, y)
zout[0] = z zout[0] = z
setup_z_Nz_Sz_inplace = """ setup_z_Nz_Sz_inplace = """
...@@ -812,10 +851,12 @@ class Gemm(GemmRelated): ...@@ -812,10 +851,12 @@ class Gemm(GemmRelated):
npy_intp dims[2]; npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0]; dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1]; dims[1] = %(_z)s->dimensions[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_z)s); %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_z)s);
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]); //fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemm_no_inplace output"); PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_no_inplace output");
%(fail)s %(fail)s
} }
} }
...@@ -853,7 +894,8 @@ class Gemm(GemmRelated): ...@@ -853,7 +894,8 @@ class Gemm(GemmRelated):
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype"); PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s %(fail)s
} }
} }
...@@ -880,14 +922,16 @@ class Gemm(GemmRelated): ...@@ -880,14 +922,16 @@ class Gemm(GemmRelated):
#undef REAL #undef REAL
""" """
def c_code(self, node, name, inp, out, sub): #DEBUG def c_code(self, node, name, inp, out, sub):
_z, _a, _x, _y, _b = inp _z, _a, _x, _y, _b = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__) % self.__class__.__name__)
if not config.blas.ldflags: if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name, (_z, _a, _x, _y, _b), (_zout, ), sub) return super(Gemm, self).c_code(node, name,
(_z, _a, _x, _y, _b), (_zout, ),
sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
...@@ -903,6 +947,7 @@ gemm_no_inplace = Gemm(inplace=False) ...@@ -903,6 +947,7 @@ gemm_no_inplace = Gemm(inplace=False)
pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace')) pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace')) pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
if maxclients is not None: if maxclients is not None:
retval = (len(node.clients) <= maxclients) retval = (len(node.clients) <= maxclients)
...@@ -939,17 +984,21 @@ def _as_scalar(res, dtype=None): ...@@ -939,17 +984,21 @@ def _as_scalar(res, dtype=None):
return rval return rval
def _is_real_matrix(res): def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \ return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 2 \ and res.type.ndim == 2 \
and res.type.broadcastable[0] == False \ and res.type.broadcastable[0] == False \
and res.type.broadcastable[1] == False #cope with tuple vs. list and res.type.broadcastable[1] == False # cope with tuple vs. list
def _is_real_vector(res): def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \ return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \ and res.type.ndim == 1 \
and res.type.broadcastable[0] == False and res.type.broadcastable[0] == False
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M) #EXPRESSION: (beta * L) + (alpha * M)
...@@ -990,7 +1039,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -990,7 +1039,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
rval = [g.dimshuffle()] rval = [g.dimshuffle()]
return rval return rval
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
if False and res_is_a(M, gemm_no_inplace, 1): if False and res_is_a(M, gemm_no_inplace, 1):
...@@ -1000,16 +1048,20 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -1000,16 +1048,20 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'GEMM', G, L #print 'GEMM', G, L
if res_is_a(G, _dot22, 1): if res_is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(dot(x,y), a, u, v, b))) #EXPRESSION: (beta * L) +
# (alpha * (gemm_no_inplace(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v))))) #EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) +
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v)) # (a * dot(u, v)))))
rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)] #EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) +
# (alpha * a * dot(u, v))
rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta),
alpha * a, u, v, 1.0)]
return rval return rval
if (G is L): if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm_no_inplace(L, alpha*a, u, v, alpha * b + beta)] rval = [gemm_no_inplace(L, alpha * a, u, v, alpha * b + beta)]
return rval return rval
if (1.0 != alpha): if (1.0 != alpha):
#at the very least, move the alpha inside the gemm_no_inplace #at the very least, move the alpha inside the gemm_no_inplace
...@@ -1017,7 +1069,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -1017,7 +1069,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
return rval return rval
if recurse_flip: if recurse_flip:
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False) return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False)
else: else:
return False return False
...@@ -1030,18 +1082,19 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -1030,18 +1082,19 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
if scale == -1: if scale == -1:
return -thing return -thing
else: else:
return scale*thing return scale * thing
try: try:
r.type.broadcastable r.type.broadcastable
except Exception: except Exception:
return None return None
if ((r.type.ndim not in (1, 2)) or if ((r.type.ndim not in (1, 2)) or
r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')): r.type.dtype not in ('float32', 'float64',
'complex64', 'complex128')):
rval.append(scaled(r)) rval.append(scaled(r))
return rval return rval
if maxclients and len(getattr(r,'clients',[])) > maxclients: if maxclients and len(getattr(r, 'clients', [])) > maxclients:
rval.append((scale, r)) rval.append((scale, r))
return rval return rval
...@@ -1074,32 +1127,35 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -1074,32 +1127,35 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
matrices.append(i) matrices.append(i)
else: else:
# just put the original arguments as in the base case # just put the original arguments as in the base case
rval.append((scale,r)) rval.append((scale, r))
return rval return rval
if len(matrices)==1: if len(matrices) == 1:
assert len(vectors)==0 assert len(vectors) == 0
m = matrices[0] m = matrices[0]
if len(scalars) == 0: if len(scalars) == 0:
_gemm_canonicalize(m, scale, rval, 1) _gemm_canonicalize(m, scale, rval, 1)
elif len(scalars) == 1: elif len(scalars) == 1:
_gemm_canonicalize(m, scaled(scalars[0]), rval, 1) _gemm_canonicalize(m, scaled(scalars[0]), rval, 1)
else: else:
_gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1) _gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]),
elif len(vectors)==1: rval, 1)
assert len(matrices)==0 elif len(vectors) == 1:
assert len(matrices) == 0
v = vectors[0] v = vectors[0]
if len(scalars) == 0: if len(scalars) == 0:
_gemm_canonicalize(v, scale, rval, 1) _gemm_canonicalize(v, scale, rval, 1)
elif len(scalars) == 1: elif len(scalars) == 1:
_gemm_canonicalize(v, scaled(scalars[0]), rval, 1) _gemm_canonicalize(v, scaled(scalars[0]), rval, 1)
else: else:
_gemm_canonicalize(v, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1) _gemm_canonicalize(v, T.mul(scaled(scalars[0]),
else: #lets not open this up *scalars[1:]), rval, 1)
rval.append((scale,r)) else: # lets not open this up
rval.append((scale, r))
else: else:
rval.append((scale,r)) rval.append((scale, r))
return rval return rval
def _factor_canonicalized(lst): def _factor_canonicalized(lst):
# remove duplicates from canonicalized list # remove duplicates from canonicalized list
...@@ -1116,17 +1172,17 @@ def _factor_canonicalized(lst): ...@@ -1116,17 +1172,17 @@ def _factor_canonicalized(lst):
# except TypeError: # except TypeError:
# print e, type(e) # print e, type(e)
i = 0 i = 0
while i < len(lst)-1: while i < len(lst) - 1:
try: try:
s_i,M_i = lst[i] s_i, M_i = lst[i]
except Exception: except Exception:
i += 1 i += 1
continue continue
j = i+1 j = i + 1
while j < len(lst): while j < len(lst):
try: try:
s_j,M_j = lst[j] s_j, M_j = lst[j]
except Exception: except Exception:
j += 1 j += 1
continue continue
...@@ -1137,9 +1193,10 @@ def _factor_canonicalized(lst): ...@@ -1137,9 +1193,10 @@ def _factor_canonicalized(lst):
del lst[j] del lst[j]
else: else:
j += 1 j += 1
i+=1 i += 1
return lst return lst
def _gemm_from_factored_list(lst): def _gemm_from_factored_list(lst):
"""Returns None, or a list to replace node.outputs """Returns None, or a list to replace node.outputs
""" """
...@@ -1171,7 +1228,7 @@ def _gemm_from_factored_list(lst): ...@@ -1171,7 +1228,7 @@ def _gemm_from_factored_list(lst):
for i in xrange(len(lst) - 1): for i in xrange(len(lst) - 1):
s_i, M_i = lst[i] s_i, M_i = lst[i]
for j in xrange(i+1, len(lst)): for j in xrange(i + 1, len(lst)):
s_j, M_j = lst[j] s_j, M_j = lst[j]
if M_i.type != M_j.type: if M_i.type != M_j.type:
...@@ -1183,15 +1240,19 @@ def _gemm_from_factored_list(lst): ...@@ -1183,15 +1240,19 @@ def _gemm_from_factored_list(lst):
#print 'GOT IT', gemm_of_sM_list #print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list: if gemm_of_sM_list:
def item_to_var(t): def item_to_var(t):
try: s,M = t try:
except Exception: return t s, M = t
if s == 1: return M except Exception:
if s == -1: return -M return t
return s*M if s == 1:
return M
if s == -1:
return -M
return s * M
assert len(gemm_of_sM_list) == 1 assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input) add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i,j)] for k, input in enumerate(lst) if k not in (i, j)]
add_inputs.extend(gemm_of_sM_list) add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1: if len(add_inputs) > 1:
rval = [T.add(*add_inputs)] rval = [T.add(*add_inputs)]
...@@ -1200,11 +1261,14 @@ def _gemm_from_factored_list(lst): ...@@ -1200,11 +1261,14 @@ def _gemm_from_factored_list(lst):
#print "RETURNING GEMM THIGN", rval #print "RETURNING GEMM THIGN", rval
return rval return rval
def _gemm_from_node2(node): def _gemm_from_node2(node):
""" """
:todo: In many expressions, there are many ways to turn it into a gemm. For example :todo: In many expressions, there are many ways to turn it into a
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm gemm. For example dot(a,b) + c + d. This function should
causes a cycle in the graph, then another application of gemm can be tried. return all of them, so that if one version of gemm causes a
cycle in the graph, then another application of gemm can be
tried.
""" """
lst = [] lst = []
...@@ -1214,10 +1278,11 @@ def _gemm_from_node2(node): ...@@ -1214,10 +1278,11 @@ def _gemm_from_node2(node):
lst = _factor_canonicalized(lst) lst = _factor_canonicalized(lst)
rval = _gemm_from_factored_list(lst) rval = _gemm_from_factored_list(lst)
# It can happen that _factor_canonicalized and _gemm_from_factored_list # It can happen that _factor_canonicalized and
# return a node with an incorrect type. This happens in particular when # _gemm_from_factored_list return a node with an incorrect
# one of the scalar factors forces the upcast of the whole expression. # type. This happens in particular when one of the scalar
# In that case, we simply skip that candidate for Gemm. This was # factors forces the upcast of the whole expression. In that
# case, we simply skip that candidate for Gemm. This was
# discussed in # discussed in
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket. # but never made it into a trac ticket.
...@@ -1225,6 +1290,7 @@ def _gemm_from_node2(node): ...@@ -1225,6 +1290,7 @@ def _gemm_from_node2(node):
if rval and (rval[0].type == node.outputs[0].type): if rval and (rval[0].type == node.outputs[0].type):
return rval return rval
class GemmOptimizer(Optimizer): class GemmOptimizer(Optimizer):
"""Graph optimizer for inserting Gemm operations""" """Graph optimizer for inserting Gemm operations"""
def __init__(self): def __init__(self):
...@@ -1343,6 +1409,7 @@ class Dot22(GemmRelated): ...@@ -1343,6 +1409,7 @@ class Dot22(GemmRelated):
_dot22 = Dot22() _dot22 = Dot22()
@local_optimizer([T.dot]) @local_optimizer([T.dot])
def local_dot_to_dot22(node): def local_dot_to_dot22(node):
# This works for tensor.outer too because basic.outer is a macro that # This works for tensor.outer too because basic.outer is a macro that
...@@ -1350,10 +1417,11 @@ def local_dot_to_dot22(node): ...@@ -1350,10 +1417,11 @@ def local_dot_to_dot22(node):
if node.op != T.dot: if node.op != T.dot:
return return
x,y = node.inputs x, y = node.inputs
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match # TODO: upcast one so the types match
_logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type) _logger.info('Not optimizing dot with inputs %s %s %s %s',
x, y, x.type, y.type)
return return
if y.type.dtype.startswith('float') or y.type.dtype.startswith('complex'): if y.type.dtype.startswith('float') or y.type.dtype.startswith('complex'):
...@@ -1362,15 +1430,18 @@ def local_dot_to_dot22(node): ...@@ -1362,15 +1430,18 @@ def local_dot_to_dot22(node):
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
if x.ndim == 2 and y.ndim == 1: if x.ndim == 2 and y.ndim == 1:
#print "local_dot_to_dot22: MV" #print "local_dot_to_dot22: MV"
return [_dot22(x, y.dimshuffle(0,'x')).dimshuffle(0)] return [_dot22(x, y.dimshuffle(0, 'x')).dimshuffle(0)]
if x.ndim == 1 and y.ndim == 2: if x.ndim == 1 and y.ndim == 2:
#print "local_dot_to_dot22: VM" #print "local_dot_to_dot22: VM"
return [_dot22(x.dimshuffle('x',0), y).dimshuffle(1)] return [_dot22(x.dimshuffle('x', 0), y).dimshuffle(1)]
if x.ndim == 1 and y.ndim == 1: if x.ndim == 1 and y.ndim == 1:
#print "local_dot_to_dot22: VV" #print "local_dot_to_dot22: VV"
return [_dot22(x.dimshuffle('x',0), y.dimshuffle(0,'x')).dimshuffle()] return [_dot22(x.dimshuffle('x', 0),
y.dimshuffle(0, 'x')).dimshuffle()]
_logger.info('Not optimizing dot with inputs %s %s %s %s',
x, y, x.type, y.type)
_logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type)
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
def local_inplace_gemm(node): def local_inplace_gemm(node):
...@@ -1383,11 +1454,13 @@ def local_inplace_gemv(node): ...@@ -1383,11 +1454,13 @@ def local_inplace_gemv(node):
if node.op == gemv_no_inplace: if node.op == gemv_no_inplace:
return [gemv_inplace(*node.inputs)] return [gemv_inplace(*node.inputs)]
@local_optimizer([ger]) @local_optimizer([ger])
def local_inplace_ger(node): def local_inplace_ger(node):
if node.op == ger: if node.op == ger:
return [ger_destructive(*node.inputs)] return [ger_destructive(*node.inputs)]
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
def local_gemm_to_gemv(node): def local_gemm_to_gemv(node):
"""GEMM acting on row or column matrices -> GEMV """GEMM acting on row or column matrices -> GEMV
...@@ -1401,6 +1474,7 @@ def local_gemm_to_gemv(node): ...@@ -1401,6 +1474,7 @@ def local_gemm_to_gemv(node):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b) r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
return [r.dimshuffle(0, 'x')] return [r.dimshuffle(0, 'x')]
@local_optimizer([gemm_no_inplace]) @local_optimizer([gemm_no_inplace])
def local_gemm_to_ger(node): def local_gemm_to_ger(node):
"""GEMM computing an outer-product -> GER """GEMM computing an outer-product -> GER
...@@ -1481,8 +1555,8 @@ blas_optdb = SequenceDB() ...@@ -1481,8 +1555,8 @@ blas_optdb = SequenceDB()
# run after numerical stability optimizations (1.5) # run after numerical stability optimizations (1.5)
optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run') optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# run before specialize (2.0) because specialize is basically a free-for-all that makes the # run before specialize (2.0) because specialize is basically a
# graph crazy. # free-for-all that makes the graph crazy.
blas_optdb.register('local_dot_to_dot22', blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5), EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
...@@ -1502,7 +1576,8 @@ blas_optdb.register('local_gemm_to_gemv', ...@@ -1502,7 +1576,8 @@ blas_optdb.register('local_gemm_to_gemv',
# After destroyhandler is in but before we try to make elemwise things inplace # After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace # Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71) # Also, need to make the gemm optimisation(step 70) happen before the
# fusion of elemwise(step 71)
blas_opt_inplace = EquilibriumOptimizer( blas_opt_inplace = EquilibriumOptimizer(
[local_inplace_gemm, local_inplace_gemv, local_inplace_ger], [local_inplace_gemm, local_inplace_gemv, local_inplace_ger],
failure_callback=EquilibriumOptimizer.warn_inplace, failure_callback=EquilibriumOptimizer.warn_inplace,
...@@ -1538,7 +1613,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1538,7 +1613,7 @@ class Dot22Scalar(GemmRelated):
bz = [x.type.broadcastable[0], y.type.broadcastable[1]] bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y,a], outputs) return Apply(self, [x, y, a], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y, scalar = inp x, y, scalar = inp
...@@ -1546,7 +1621,8 @@ class Dot22Scalar(GemmRelated): ...@@ -1546,7 +1621,8 @@ class Dot22Scalar(GemmRelated):
try: try:
z[0] = numpy.asarray(scalar * numpy.dot(x, y)) z[0] = numpy.asarray(scalar * numpy.dot(x, y))
except ValueError, e: except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we
# mean to add that
e.args = e.args + (x.shape, y.shape) e.args = e.args + (x.shape, y.shape)
raise raise
...@@ -1558,7 +1634,8 @@ class Dot22Scalar(GemmRelated): ...@@ -1558,7 +1634,8 @@ class Dot22Scalar(GemmRelated):
check_ab_double_or_float = """ check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != PyArray_DOUBLE) if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT)) && (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError,
"type(a) is not double or float"); %(fail)s;}
""" """
case_float_ab_constants = """ case_float_ab_constants = """
...@@ -1579,14 +1656,15 @@ class Dot22Scalar(GemmRelated): ...@@ -1579,14 +1656,15 @@ class Dot22Scalar(GemmRelated):
double b = 0.0; double b = 0.0;
""" """
def c_code(self, node, name, inp, out, sub): #DEBUG def c_code(self, node, name, inp, out, sub):
_x, _y, _a = inp _x, _y, _a = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub) return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
(_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code return full_code
...@@ -1599,21 +1677,29 @@ class Dot22Scalar(GemmRelated): ...@@ -1599,21 +1677,29 @@ class Dot22Scalar(GemmRelated):
_dot22scalar = Dot22Scalar() _dot22scalar = Dot22Scalar()
@local_optimizer([T.mul]) @local_optimizer([T.mul])
def local_dot22_to_dot22scalar(node): def local_dot22_to_dot22scalar(node):
""" """
:note: we upcast the scalar if after the multiplication with the dot this give the same type. :note: we upcast the scalar if after the multiplication with the
.. note: dot this give the same type.
We execute this optimizer after the gemm optimizer. This allow to give more priority to gemm that give more speed up then this optimizer, but allow the gemm optimizer to ignore this op.
.. note: We execute this optimizer after the gemm optimizer. This
allow to give more priority to gemm that give more speed up
then this optimizer, but allow the gemm optimizer to ignore
this op.
TODO: support when we can reorder the mul to generate a dot22scalar or fix the canonizer to merge them(1 mul with multiple inputs)
TODO: support when we can reorder the mul to generate a
dot22scalar or fix the canonizer to merge them(1 mul with multiple
inputs)
""" """
if node.op != T.mul: if node.op != T.mul:
return False return False
i_dot22 = [x.owner and x.owner.op==_dot22 for x in node.inputs] i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
if not any(i_dot22): return False # no dot22 if not any(i_dot22):
if i_dot22.count(True)>1: return False # no dot22
if i_dot22.count(True) > 1:
#TODO: try each of them. #TODO: try each of them.
pass pass
#return False #TODO fix #return False #TODO fix
...@@ -1621,34 +1707,39 @@ def local_dot22_to_dot22scalar(node): ...@@ -1621,34 +1707,39 @@ def local_dot22_to_dot22scalar(node):
d = node.inputs[dot22_idx] d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar): if not any(i_scalar):
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs] i_mul = [x.owner and x.owner.op == T.mul for x in node.inputs]
if not any(i_mul): if not any(i_mul):
#no scalar in input and no multiplication #no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph. #if their was a multiplication we couls reorder the graph
#by the associativity of the graph.
return False return False
#maybe we can reorder the graph as this mul have a mul in input. #maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together. #The canonizer should have merged those mul together.
#We support only 1 additional level of mul. #We support only 1 additional level of mul.
mul_idx = i_mul.index(True)#we take the first mul! mul_idx = i_mul.index(True) # we take the first mul!
m = node.inputs[mul_idx] m = node.inputs[mul_idx]
if len(m.owner.inputs)==2 and any([_as_scalar(x, dtype=d.dtype) for x in m.owner.inputs]): if len(m.owner.inputs) == 2 and any([_as_scalar(x, dtype=d.dtype)
for x in m.owner.inputs]):
scalar_idx = -1 scalar_idx = -1
for i,x in enumerate(m.owner.inputs): for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(x.type.dtype,d.type.dtype) if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype)
== d.type.dtype): == d.type.dtype):
scalar_idx = i scalar_idx = i
break break
if scalar_idx < 0: if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type ' _logger.info('Not optimizing dot22 with inputs %s %s, as the'
'of the scalar cannot be upcasted to the matrix type', ' type of the scalar cannot be upcasted to the'
' matrix type',
node.inputs, [x.type for x in node.inputs]) node.inputs, [x.type for x in node.inputs])
return False return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype) a = T.cast(_as_scalar(m.owner.inputs[scalar_idx],
dtype=d.dtype), d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# What about the other inputs to the original node that were # What about the other inputs to the original node that were
# neither part of the dot22 or this mul? # neither part of the dot22 or this mul?
...@@ -1657,7 +1748,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1657,7 +1748,7 @@ def local_dot22_to_dot22scalar(node):
assert all((i in (dot22_idx, mul_idx)) assert all((i in (dot22_idx, mul_idx))
for i in xrange(len(node.inputs))) for i in xrange(len(node.inputs)))
return [T.mul(m.owner.inputs[1-i],dot)] return [T.mul(m.owner.inputs[1 - i], dot)]
elif m.owner and m.owner.op == T.mul: elif m.owner and m.owner.op == T.mul:
_logger.info('Not optimizing dot22 with inputs %s %s %s %s. ' _logger.info('Not optimizing dot22 with inputs %s %s %s %s. '
'we need to check in a recursive way in the mul if we can ' 'we need to check in a recursive way in the mul if we can '
...@@ -1667,9 +1758,9 @@ def local_dot22_to_dot22scalar(node): ...@@ -1667,9 +1758,9 @@ def local_dot22_to_dot22scalar(node):
return False return False
scalar_idx = -1 scalar_idx = -1
for i,x in enumerate(node.inputs): for i, x in enumerate(node.inputs):
if (i_scalar[i] is not None if (i_scalar[i] is not None
and (theano.scalar.upcast(x.type.dtype,d.type.dtype) and (theano.scalar.upcast(x.type.dtype, d.type.dtype)
== d.type.dtype)): == d.type.dtype)):
scalar_idx = i scalar_idx = i
break break
...@@ -1689,15 +1780,17 @@ def local_dot22_to_dot22scalar(node): ...@@ -1689,15 +1780,17 @@ def local_dot22_to_dot22scalar(node):
if len(o) == 0: if len(o) == 0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)] return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else: else:
return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)] return [T.mul(_dot22scalar(d.owner.inputs[0],
d.owner.inputs[1], a), *o)]
#must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar #must happen after gemm as the gemm optimizer don't understant
#dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar', blas_optdb.register('local_dot22_to_dot22scalar',
EquilibriumOptimizer([local_dot22_to_dot22scalar ], max_use_ratio=5), EquilibriumOptimizer([local_dot22_to_dot22scalar], max_use_ratio=5),
11, 'fast_run') 11, 'fast_run')
from opt import register_specialize, register_canonicalize #from opt import register_specialize, register_canonicalize
#@register_specialize #@register_specialize
@local_optimizer([]) @local_optimizer([])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论