提交 d18c322f authored 作者: Frederic's avatar Frederic

pep8

上级 927ac9a4
......@@ -39,12 +39,12 @@ Dot22Scalar is a GEMM where b=0 and Z is allocated every time.
Gemm is a GEMM in all its generality.
In the future we can refactor the GemmRelated, Gemm, Dot22 and
Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a normal Gemm, but
with an additional configuration variable that says to ignore the input Z.
Setting that configuration variable to True would make Gemm2 equivalent to the
current Dot22 and Dot22Scalar. This would make the file a lot easier to read,
and save a few hundred lines of library, to say nothing of testing and
documentation.
Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a
normal Gemm, but with an additional configuration variable that says
to ignore the input Z. Setting that configuration variable to True
would make Gemm2 equivalent to the current Dot22 and Dot22Scalar.
This would make the file a lot easier to read, and save a few hundred
lines of library, to say nothing of testing and documentation.
GEMV: Gemv
......@@ -109,21 +109,24 @@ This is complicated, done in GemmOptimizer.
Identify Dot22Scalar from Dot22
-------------------------------
Dot22 Ops that remain after the GemmOptimizer is done have not qualified as GEMM
Ops. Still they might be scaled by a factor, in which case we use Dot22Scalar
which is like Gemm, but without the b and the Z. In the future it would be good
to merge this into the GemmOptimizer.
Dot22 Ops that remain after the GemmOptimizer is done have not
qualified as GEMM Ops. Still they might be scaled by a factor, in
which case we use Dot22Scalar which is like Gemm, but without the b
and the Z. In the future it would be good to merge this into the
GemmOptimizer.
Specialize Gemm to Gemv
-----------------------
If arguments to GEMM are dimshuffled vectors, then we can use GEMV instead. This
optimization is `local_gemm_to_gemv`.
If arguments to GEMM are dimshuffled vectors, then we can use GEMV
instead. This optimization is `local_gemm_to_gemv`.
"""
import logging, copy, os, sys
import copy
import logging
import os
import sys
import numpy
import numpy.distutils
......@@ -137,7 +140,7 @@ from theano.compile.mode import optdb
from theano.gof.python25 import all, any
import theano.scalar
import basic as T
from theano.tensor.blas_headers import blas_header_text #, cblas_header_text
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.opt import local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas')
......@@ -146,16 +149,17 @@ try:
import scipy.linalg.blas
_have_fblas = True
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
numpy.dtype('float32'): scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'): scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'): scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'): scipy.linalg.blas.fblas.zgemv,
}
except ImportError, e:
_have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas.fblas. '
'Falling back on slower implementations (%s)', str(e))
class Gemv(Op):
"""
expression is beta * y + alpha * A x
......@@ -166,12 +170,12 @@ class Gemv(Op):
output is a vector that can be inplace on y
"""
def __init__(self, inplace):
self.inplace=inplace
self.inplace = inplace
if inplace:
self.destroy_map={0:[0]}
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self)==type(other) and self.inplace == other.inplace
return type(self) == type(other) and self.inplace == other.inplace
def __str__(self):
if self.inplace:
......@@ -189,41 +193,50 @@ class Gemv(Op):
alpha = T.as_tensor_variable(alpha)
beta = T.as_tensor_variable(beta)
if y.dtype != A.dtype or y.dtype != x.dtype:
raise TypeError('Gemv requires matching dtypes', (y.dtype, A.dtype, x.dtype))
if A.ndim != 2: raise TypeError('gemv requires matrix for A', A.type)
if x.ndim != 1: raise TypeError('gemv requires vector for x', x.type)
if y.ndim != 1: raise TypeError('gemv requires vector for y', y.type)
raise TypeError('Gemv requires matching dtypes',
(y.dtype, A.dtype, x.dtype))
if A.ndim != 2:
raise TypeError('gemv requires matrix for A', A.type)
if x.ndim != 1:
raise TypeError('gemv requires vector for x', x.type)
if y.ndim != 1:
raise TypeError('gemv requires vector for y', y.type)
if y.broadcastable[0] != A.broadcastable[0]:
raise TypeError('broadcastable mismatch between y and A', (y.type, A.type))
# The following is not grounds for error
# because as long as sizes are 1 at time of perform() there is no problem
raise TypeError('broadcastable mismatch between y and A',
(y.type, A.type))
# The following is not grounds for error because as long as
# sizes are 1 at time of perform() there is no problem
#if x.broadcastable[0] != A.broadcastable[1]:
#raise TypeError('broadcastable mismatch between x and A', (x.type, A.type))
#raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
if _have_fblas and y.shape[0]!=0 and x.shape[0]!=0:
if _have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))#
% (y.shape, A.shape, x.shape))
#Here I suppose that A is in c order. If we don't make it explicitly
# as fortran order, scipy 0.7.2 seam to create a copy in fortran
# order instead of just reshaping it and using the trans flag.
#Here I suppose that A is in C order. If we don't make it
# explicitly as Fortran order, scipy 0.7.2 seems to create
# a copy in Fortran order instead of just reshaping it
# and using the trans flag.
#If A is already in Fortran order, making it C order and using the
# trans flag doesn't seem to cause a slowdown.
#out_storage[0][0] = gemv(alpha, A, x, beta, y, overwrite_y=self.inplace)
out_storage[0][0] = gemv(alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True)
#out_storage[0][0] = gemv(alpha, A, x, beta, y,
# overwrite_y=self.inplace)
out_storage[0][0] = gemv(alpha, A.T, x, beta, y,
overwrite_y=self.inplace, trans=True)
else:
out = numpy.dot(A, x)
if alpha != 1:
out *= alpha
if beta != 1:
out += beta * y
out += beta * y
else:
out += y
out_storage[0][0] = numpy.asarray(out, dtype=y.dtype)
......@@ -231,6 +244,7 @@ class Gemv(Op):
gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True)
class Ger(Op):
"""
BLAS defines general rank-1 update GER as A <- A + alpha x y'
......@@ -245,12 +259,13 @@ class Ger(Op):
and override the make_thunk() method to use Scipy and C respectively.
"""
def __init__(self, destructive):
self.destructive=destructive
self.destructive = destructive
if destructive:
self.destroy_map={0:[0]}
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self)==type(other) and self.destructive == other.destructive
return (type(self) == type(other) and
self.destructive == other.destructive)
def __hash__(self):
return hash(type(self)) ^ hash(self.destructive)
......@@ -299,30 +314,36 @@ class Ger(Op):
ger = Ger(destructive=False)
ger_destructive = Ger(destructive=True)
def default_blas_ldflags():
try:
# If we are in a EPD installation, mkl is available
blas_info = numpy.distutils.__config__.blas_opt_info
if "EPD" in sys.version:
if sys.platform == 'win32':
return ' '.join(
['-L%s' % os.path.join(sys.prefix, "Scripts")] +
# Why, on Windows, are the libraries used not the
# same as what is in
# numpy.distutils.__config__.blas_opt_info['libraries']?
# blas_info['libraries']?
['-l%s' % l for l in ["mk2_core", "mk2_intel_thread",
"mk2_rt"]])
return ' '.join(
['-L%s' % os.path.join(sys.prefix, "lib")] +
['-l%s' % l for l in numpy.distutils.__config__.blas_opt_info['libraries']])
#if numpy was linked with library that are not installed, we can't reuse them.
if all(not os.path.exists(dir) for dir in numpy.distutils.__config__.blas_opt_info['library_dirs']):
['-l%s' % l for l in blas_info['libraries']])
#if numpy was linked with libraries that are not installed, we
#can't reuse them.
if all(not os.path.exists(dir) for dir in blas_info['library_dirs']):
return "-lblas"
return ' '.join(
#TODO: the Gemm op below should separate the -L and -l arguments into the two callbacks that CLinker uses for that stuff.
# for now, we just pass the whole ldflags as the -l options part.
['-L%s'%l for l in numpy.distutils.__config__.blas_opt_info['library_dirs']] +
['-l%s'%l for l in numpy.distutils.__config__.blas_opt_info['libraries']])
# ['-I%s'%l for l in numpy.distutils.__config__.blas_opt_info['include_dirs']])
#TODO: the Gemm op below should separate the
# -L and -l arguments into the two callbacks
# that CLinker uses for that stuff. for now,
# we just pass the whole ldflags as the -l
# options part.
['-L%s' % l for l in blas_info['library_dirs']] +
['-l%s' % l for l in blas_info['libraries']])
# ['-I%s' % l for l in blas_info['include_dirs']])
except KeyError:
return "-lblas"
......@@ -330,23 +351,27 @@ AddConfigVar('blas.ldflags',
"lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags()))
@utils.memoize
def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
"""Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation.
Default: ['blas'], but configuration variable config.blas.ldflags overrides this.
Default: ['blas'], but configuration variable config.blas.ldflags
overrides this.
"""
rval = []
if libs_dir:
found_dyn=False
dirs = [x[2:] for x in config.blas.ldflags.split() if x.startswith('-L')]
found_dyn = False
dirs = [x[2:] for x in config.blas.ldflags.split()
if x.startswith('-L')]
l = ldflags()
for d in dirs:
for f in os.listdir(d):
if f.endswith('.so') or f.endswith('.dylib') or f.endswith('.dll'):
if any([f.find(ll)>=0 for ll in l]):
found_dyn=True
if (f.endswith('.so') or f.endswith('.dylib') or
f.endswith('.dll')):
if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True
if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use "
......@@ -361,17 +386,21 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
if libs_dir and t1 == 'L':
rval.append(t[2:])
elif include_dir and t1 == 'I':
raise ValueError('Include dirs are not used for blas. We disable this as this can hide other headers and this is not wanted.', t)
raise ValueError('Include dirs are not used for blas. We disable'
' this as this can hide other headers and this'
' is not wanted.', t)
rval.append(t[2:])
elif libs and t1=='l': # example -lmkl
elif libs and t1 == 'l': # example -lmkl
rval.append(t[2:])
elif flags and t1 not in ['L','I','l']: # example -openmp
elif flags and t1 not in ['L', 'I', 'l']: # example -openmp
rval.append(t)
elif flags and t1 == 'L':
#to find it when we load the compiled op if the env of the used is not well configured.
rval.append('-Wl,-rpath,'+t[2:])
#to find it when we load the compiled op if the env of the
#user is not well configured.
rval.append('-Wl,-rpath,' + t[2:])
return rval
class GemmRelated(Op):
"""Base class for Gemm and Dot22
......@@ -379,10 +408,13 @@ class GemmRelated(Op):
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def __str__(self):
return self.__class__.__name__
def c_support_code(self):
#return cblas_header_text()
mod_str = """
......@@ -397,6 +429,7 @@ class GemmRelated(Op):
}
"""
return blas_header_text() + mod_str
def c_headers(self):
# std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff.
......@@ -670,6 +703,7 @@ class GemmRelated(Op):
def build_gemm_version(self):
return (12,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......@@ -681,14 +715,15 @@ class Gemm(GemmRelated):
b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z,
whereas the bottom form is not. Gemm works in-place on the storage
associated with z, and the L{Variable} returned by Gemm has a storage that
will be aliased to the storage of the z argument. Because of this in-place
computation, an L{Apply} of this op will destroy the L{Variable} z on
which it operates. (See L{DestructiveOps} for an explanation of what
destroying means in the context of theano graphs. See L{BlasLapackSupport} for
more optimized linear algebra operations.)
The difference between the two is that the top form is destructive
on z, whereas the bottom form is not. Gemm works in-place on the
storage associated with z, and the L{Variable} returned by Gemm
has a storage that will be aliased to the storage of the z
argument. Because of this in-place computation, an L{Apply} of
this op will destroy the L{Variable} z on which it operates. (See
L{DestructiveOps} for an explanation of what destroying means in
the context of theano graphs. See L{BlasLapackSupport} for more
optimized linear algebra operations.)
"""
E_rank = 'gemm only works for rank 2'
......@@ -698,18 +733,20 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes'
def __init__(self, inplace):
self.__setstate__({'inplace':inplace})
self.__setstate__({'inplace': inplace})
def __eq__(self, other):
return (type(self) == type(other)\
and self.inplace == other.inplace)
return (type(self) == type(other) and
self.inplace == other.inplace)
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
if self.inplace: inplace_str = 'inplace'
else: inplace_str = 'no_inplace'
if self.inplace:
inplace_str = 'inplace'
else:
inplace_str = 'no_inplace'
return '%s{%s}' % (self.__class__.__name__, inplace_str)
def __setstate__(self, dct):
......@@ -727,9 +764,11 @@ class Gemm(GemmRelated):
def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs)
if len(inputs) != 5:
raise TypeError("Wrong number of inputs for %s (expected 5, got %s)" % (self, len(inputs)))
raise TypeError(
"Wrong number of inputs for %s (expected 5, got %s)" %
(self, len(inputs)))
z, a, x, y, b = inputs
zr, xr, yr = [set(view_roots(i)) for i in z,x,y]
zr, xr, yr = [set(view_roots(i)) for i in z, x, y]
# TODO: justify / delete
if zr.intersection(xr):
......@@ -765,28 +804,28 @@ class Gemm(GemmRelated):
assert a.shape == ()
assert b.shape == ()
if not self.inplace:
z = z.copy() # the original z will not be changed
z = z.copy() # the original z will not be changed
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
z.itemset(z * a + b * numpy.dot(x, y))
zout[0] = z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
z[:] = numpy.dot(x, y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
z[:] = -numpy.dot(x, y)
else:
z[:] = a * numpy.dot(x,y)
z[:] = a * numpy.dot(x, y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
z += numpy.dot(x, y)
elif a == -1.0:
z -= numpy.dot(x,y)
z -= numpy.dot(x, y)
else:
z += a * numpy.dot(x,y)
z += a * numpy.dot(x, y)
else:
z *= b
z += a * numpy.dot(x,y)
z += a * numpy.dot(x, y)
zout[0] = z
setup_z_Nz_Sz_inplace = """
......@@ -812,10 +851,12 @@ class Gemm(GemmRelated):
npy_intp dims[2];
dims[0] = %(_z)s->dimensions[0];
dims[1] = %(_z)s->dimensions[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_z)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_z)s);
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemm_no_inplace output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_no_inplace output");
%(fail)s
}
}
......@@ -853,7 +894,8 @@ class Gemm(GemmRelated):
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
......@@ -880,14 +922,16 @@ class Gemm(GemmRelated):
#undef REAL
"""
def c_code(self, node, name, inp, out, sub): #DEBUG
def c_code(self, node, name, inp, out, sub):
_z, _a, _x, _y, _b = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name, (_z, _a, _x, _y, _b), (_zout, ), sub)
return super(Gemm, self).c_code(node, name,
(_z, _a, _x, _y, _b), (_zout, ),
sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
......@@ -903,6 +947,7 @@ gemm_no_inplace = Gemm(inplace=False)
pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace'))
def res_is_a(node, op, maxclients=None):
if maxclients is not None:
retval = (len(node.clients) <= maxclients)
......@@ -939,17 +984,21 @@ def _as_scalar(res, dtype=None):
return rval
def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 2 \
and res.type.broadcastable[0] == False \
and res.type.broadcastable[1] == False #cope with tuple vs. list
and res.type.broadcastable[1] == False # cope with tuple vs. list
def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \
and res.type.broadcastable[0] == False
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
......@@ -990,7 +1039,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
rval = [g.dimshuffle()]
return rval
# this is False'd out because of inadequate testing.
# TODO see ticket #237
if False and res_is_a(M, gemm_no_inplace, 1):
......@@ -1000,16 +1048,20 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'GEMM', G, L
if res_is_a(G, _dot22, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(dot(x,y), a, u, v, b)))
#EXPRESSION: (beta * L) +
# (alpha * (gemm_no_inplace(dot(x,y), a, u, v, b)))
x, y = G.owner.inputs
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) +
# (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) +
# (alpha * a * dot(u, v))
rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta),
alpha * a, u, v, 1.0)]
return rval
if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm_no_inplace(L, alpha*a, u, v, alpha * b + beta)]
rval = [gemm_no_inplace(L, alpha * a, u, v, alpha * b + beta)]
return rval
if (1.0 != alpha):
#at the very least, move the alpha inside the gemm_no_inplace
......@@ -1017,7 +1069,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
return rval
if recurse_flip:
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip = False)
return _beta_L_plus_alpha_M(alpha, M, beta, L, recurse_flip=False)
else:
return False
......@@ -1030,18 +1082,19 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
if scale == -1:
return -thing
else:
return scale*thing
return scale * thing
try:
r.type.broadcastable
except Exception:
return None
if ((r.type.ndim not in (1, 2)) or
r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')):
r.type.dtype not in ('float32', 'float64',
'complex64', 'complex128')):
rval.append(scaled(r))
return rval
if maxclients and len(getattr(r,'clients',[])) > maxclients:
if maxclients and len(getattr(r, 'clients', [])) > maxclients:
rval.append((scale, r))
return rval
......@@ -1074,32 +1127,35 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
matrices.append(i)
else:
# just put the original arguments as in the base case
rval.append((scale,r))
rval.append((scale, r))
return rval
if len(matrices)==1:
assert len(vectors)==0
if len(matrices) == 1:
assert len(vectors) == 0
m = matrices[0]
if len(scalars) == 0:
_gemm_canonicalize(m, scale, rval, 1)
elif len(scalars) == 1:
_gemm_canonicalize(m, scaled(scalars[0]), rval, 1)
else:
_gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
elif len(vectors)==1:
assert len(matrices)==0
_gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]),
rval, 1)
elif len(vectors) == 1:
assert len(matrices) == 0
v = vectors[0]
if len(scalars) == 0:
_gemm_canonicalize(v, scale, rval, 1)
elif len(scalars) == 1:
_gemm_canonicalize(v, scaled(scalars[0]), rval, 1)
else:
_gemm_canonicalize(v, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
else: #lets not open this up
rval.append((scale,r))
_gemm_canonicalize(v, T.mul(scaled(scalars[0]),
*scalars[1:]), rval, 1)
else: # lets not open this up
rval.append((scale, r))
else:
rval.append((scale,r))
rval.append((scale, r))
return rval
def _factor_canonicalized(lst):
# remove duplicates from canonicalized list
......@@ -1116,17 +1172,17 @@ def _factor_canonicalized(lst):
# except TypeError:
# print e, type(e)
i = 0
while i < len(lst)-1:
while i < len(lst) - 1:
try:
s_i,M_i = lst[i]
s_i, M_i = lst[i]
except Exception:
i += 1
continue
j = i+1
j = i + 1
while j < len(lst):
try:
s_j,M_j = lst[j]
s_j, M_j = lst[j]
except Exception:
j += 1
continue
......@@ -1137,9 +1193,10 @@ def _factor_canonicalized(lst):
del lst[j]
else:
j += 1
i+=1
i += 1
return lst
def _gemm_from_factored_list(lst):
"""Returns None, or a list to replace node.outputs
"""
......@@ -1171,7 +1228,7 @@ def _gemm_from_factored_list(lst):
for i in xrange(len(lst) - 1):
s_i, M_i = lst[i]
for j in xrange(i+1, len(lst)):
for j in xrange(i + 1, len(lst)):
s_j, M_j = lst[j]
if M_i.type != M_j.type:
......@@ -1183,15 +1240,19 @@ def _gemm_from_factored_list(lst):
#print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list:
def item_to_var(t):
try: s,M = t
except Exception: return t
if s == 1: return M
if s == -1: return -M
return s*M
try:
s, M = t
except Exception:
return t
if s == 1:
return M
if s == -1:
return -M
return s * M
assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i,j)]
for k, input in enumerate(lst) if k not in (i, j)]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
rval = [T.add(*add_inputs)]
......@@ -1200,11 +1261,14 @@ def _gemm_from_factored_list(lst):
#print "RETURNING GEMM THIGN", rval
return rval
def _gemm_from_node2(node):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
:todo: In many expressions, there are many ways to turn it into a
gemm. For example dot(a,b) + c + d. This function should
return all of them, so that if one version of gemm causes a
cycle in the graph, then another application of gemm can be
tried.
"""
lst = []
......@@ -1214,10 +1278,11 @@ def _gemm_from_node2(node):
lst = _factor_canonicalized(lst)
rval = _gemm_from_factored_list(lst)
# It can happen that _factor_canonicalized and _gemm_from_factored_list
# return a node with an incorrect type. This happens in particular when
# one of the scalar factors forces the upcast of the whole expression.
# In that case, we simply skip that candidate for Gemm. This was
# It can happen that _factor_canonicalized and
# _gemm_from_factored_list return a node with an incorrect
# type. This happens in particular when one of the scalar
# factors forces the upcast of the whole expression. In that
# case, we simply skip that candidate for Gemm. This was
# discussed in
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket.
......@@ -1225,6 +1290,7 @@ def _gemm_from_node2(node):
if rval and (rval[0].type == node.outputs[0].type):
return rval
class GemmOptimizer(Optimizer):
"""Graph optimizer for inserting Gemm operations"""
def __init__(self):
......@@ -1343,6 +1409,7 @@ class Dot22(GemmRelated):
_dot22 = Dot22()
@local_optimizer([T.dot])
def local_dot_to_dot22(node):
# This works for tensor.outer too because basic.outer is a macro that
......@@ -1350,10 +1417,11 @@ def local_dot_to_dot22(node):
if node.op != T.dot:
return
x,y = node.inputs
x, y = node.inputs
if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match
_logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type)
_logger.info('Not optimizing dot with inputs %s %s %s %s',
x, y, x.type, y.type)
return
if y.type.dtype.startswith('float') or y.type.dtype.startswith('complex'):
......@@ -1362,15 +1430,18 @@ def local_dot_to_dot22(node):
return [_dot22(*node.inputs)]
if x.ndim == 2 and y.ndim == 1:
#print "local_dot_to_dot22: MV"
return [_dot22(x, y.dimshuffle(0,'x')).dimshuffle(0)]
return [_dot22(x, y.dimshuffle(0, 'x')).dimshuffle(0)]
if x.ndim == 1 and y.ndim == 2:
#print "local_dot_to_dot22: VM"
return [_dot22(x.dimshuffle('x',0), y).dimshuffle(1)]
return [_dot22(x.dimshuffle('x', 0), y).dimshuffle(1)]
if x.ndim == 1 and y.ndim == 1:
#print "local_dot_to_dot22: VV"
return [_dot22(x.dimshuffle('x',0), y.dimshuffle(0,'x')).dimshuffle()]
return [_dot22(x.dimshuffle('x', 0),
y.dimshuffle(0, 'x')).dimshuffle()]
_logger.info('Not optimizing dot with inputs %s %s %s %s',
x, y, x.type, y.type)
_logger.info('Not optimizing dot with inputs %s %s %s %s', x, y, x.type, y.type)
@local_optimizer([gemm_no_inplace])
def local_inplace_gemm(node):
......@@ -1383,11 +1454,13 @@ def local_inplace_gemv(node):
if node.op == gemv_no_inplace:
return [gemv_inplace(*node.inputs)]
@local_optimizer([ger])
def local_inplace_ger(node):
if node.op == ger:
return [ger_destructive(*node.inputs)]
@local_optimizer([gemm_no_inplace])
def local_gemm_to_gemv(node):
"""GEMM acting on row or column matrices -> GEMV
......@@ -1401,6 +1474,7 @@ def local_gemm_to_gemv(node):
r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
return [r.dimshuffle(0, 'x')]
@local_optimizer([gemm_no_inplace])
def local_gemm_to_ger(node):
"""GEMM computing an outer-product -> GER
......@@ -1481,8 +1555,8 @@ blas_optdb = SequenceDB()
# run after numerical stability optimizations (1.5)
optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# run before specialize (2.0) because specialize is basically a free-for-all that makes the
# graph crazy.
# run before specialize (2.0) because specialize is basically a
# free-for-all that makes the graph crazy.
blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
......@@ -1502,7 +1576,8 @@ blas_optdb.register('local_gemm_to_gemv',
# After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
# Also, need to make the gemm optimisation(step 70) happen before the
# fusion of elemwise(step 71)
blas_opt_inplace = EquilibriumOptimizer(
[local_inplace_gemm, local_inplace_gemv, local_inplace_ger],
failure_callback=EquilibriumOptimizer.warn_inplace,
......@@ -1538,7 +1613,7 @@ class Dot22Scalar(GemmRelated):
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y,a], outputs)
return Apply(self, [x, y, a], outputs)
def perform(self, node, inp, out):
x, y, scalar = inp
......@@ -1546,7 +1621,8 @@ class Dot22Scalar(GemmRelated):
try:
z[0] = numpy.asarray(scalar * numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
# The error raised by numpy has no shape information; we
# add it here
e.args = e.args + (x.shape, y.shape)
raise
......@@ -1558,7 +1634,8 @@ class Dot22Scalar(GemmRelated):
check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
{PyErr_SetString(PyExc_NotImplementedError,
"type(a) is not double or float"); %(fail)s;}
"""
case_float_ab_constants = """
......@@ -1579,14 +1656,15 @@ class Dot22Scalar(GemmRelated):
double b = 0.0;
"""
def c_code(self, node, name, inp, out, sub): #DEBUG
def c_code(self, node, name, inp, out, sub):
_x, _y, _a = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), (_zout, ), sub)
return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
(_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub)
return full_code
......@@ -1599,21 +1677,29 @@ class Dot22Scalar(GemmRelated):
_dot22scalar = Dot22Scalar()
@local_optimizer([T.mul])
def local_dot22_to_dot22scalar(node):
"""
:note: we upcast the scalar if after the multiplication with the dot this give the same type.
.. note:
We execute this optimizer after the gemm optimizer. This allow to give more priority to gemm that give more speed up then this optimizer, but allow the gemm optimizer to ignore this op.
:note: we upcast the scalar if, after the multiplication with the
dot, this gives the same type.
.. note: We execute this optimizer after the gemm optimizer. This
allows giving more priority to gemm, which gives more speed up
than this optimizer, but allows the gemm optimizer to ignore
this op.
TODO: support when we can reorder the mul to generate a dot22scalar or fix the canonizer to merge them(1 mul with multiple inputs)
TODO: support the case when we can reorder the mul to generate a
dot22scalar, or fix the canonizer to merge them (1 mul with multiple
inputs)
"""
if node.op != T.mul:
return False
i_dot22 = [x.owner and x.owner.op==_dot22 for x in node.inputs]
if not any(i_dot22): return False # no dot22
if i_dot22.count(True)>1:
i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
if not any(i_dot22):
return False # no dot22
if i_dot22.count(True) > 1:
#TODO: try each of them.
pass
#return False #TODO fix
......@@ -1621,34 +1707,39 @@ def local_dot22_to_dot22scalar(node):
d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar):
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs]
i_mul = [x.owner and x.owner.op == T.mul for x in node.inputs]
if not any(i_mul):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
#if there was a multiplication we could reorder the graph
#by the associativity of multiplication.
return False
#maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together.
#We support only 1 additional level of mul.
mul_idx = i_mul.index(True)#we take the first mul!
mul_idx = i_mul.index(True) # we take the first mul!
m = node.inputs[mul_idx]
if len(m.owner.inputs)==2 and any([_as_scalar(x, dtype=d.dtype) for x in m.owner.inputs]):
if len(m.owner.inputs) == 2 and any([_as_scalar(x, dtype=d.dtype)
for x in m.owner.inputs]):
scalar_idx = -1
for i,x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype)
== d.type.dtype):
scalar_idx = i
break
if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type',
_logger.info('Not optimizing dot22 with inputs %s %s, as the'
' type of the scalar cannot be upcasted to the'
' matrix type',
node.inputs, [x.type for x in node.inputs])
return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx],
dtype=d.dtype), d.type.dtype)
assert not a.type.ndim
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# What about the other inputs to the original node that were
# neither part of the dot22 or this mul?
......@@ -1657,7 +1748,7 @@ def local_dot22_to_dot22scalar(node):
assert all((i in (dot22_idx, mul_idx))
for i in xrange(len(node.inputs)))
return [T.mul(m.owner.inputs[1-i],dot)]
return [T.mul(m.owner.inputs[1 - i], dot)]
elif m.owner and m.owner.op == T.mul:
_logger.info('Not optimizing dot22 with inputs %s %s %s %s. '
'we need to check in a recursive way in the mul if we can '
......@@ -1667,9 +1758,9 @@ def local_dot22_to_dot22scalar(node):
return False
scalar_idx = -1
for i,x in enumerate(node.inputs):
for i, x in enumerate(node.inputs):
if (i_scalar[i] is not None
and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
and (theano.scalar.upcast(x.type.dtype, d.type.dtype)
== d.type.dtype)):
scalar_idx = i
break
......@@ -1689,15 +1780,17 @@ def local_dot22_to_dot22scalar(node):
if len(o) == 0:
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else:
return [T.mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
return [T.mul(_dot22scalar(d.owner.inputs[0],
d.owner.inputs[1], a), *o)]
#must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar
#must happen after gemm as the gemm optimizer doesn't understand
#dot22scalar and gemm gives more speed up than dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar',
EquilibriumOptimizer([local_dot22_to_dot22scalar ], max_use_ratio=5),
EquilibriumOptimizer([local_dot22_to_dot22scalar], max_use_ratio=5),
11, 'fast_run')
from opt import register_specialize, register_canonicalize
#from opt import register_specialize, register_canonicalize
#@register_specialize
@local_optimizer([])
def local_print_as_we_go_along(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论