提交 ce291230 authored 作者: Frederic Bastien's avatar Frederic Bastien

white space fix.

上级 3e933636
......@@ -4,8 +4,8 @@ import sys, traceback, logging, copy, os
import numpy
import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
from theano.gof import (utils, Op, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer)
from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
......@@ -17,7 +17,7 @@ import basic as T
from theano.tensor.tsor_apply import Apply
#NB: this clobbers the builtin 'compile' symbol
from theano import compile #to register the optimizer built by this file
from theano import compile #to register the optimizer built by this file
from theano.tensor.blas_headers import cblas_header_text, blas_header_text
......@@ -108,11 +108,11 @@ def default_blas_ldflags():
if all(not os.path.exists(dir) for dir in numpy.distutils.__config__.blas_opt_info['library_dirs']):
return "-lblas"
return ' '.join(
#TODO: the Gemm op below should separate the -L and -l arguments into the two callbacks that CLinker uses for that stuff.
#TODO: the Gemm op below should separate the -L and -l arguments into the two callbacks that CLinker uses for that stuff.
# for now, we just pass the whole ldflags as the -l options part.
['-L%s'%l for l in numpy.distutils.__config__.blas_opt_info['library_dirs']] +
['-l%s'%l for l in numpy.distutils.__config__.blas_opt_info['libraries']])
# ['-I%s'%l for l in numpy.distutils.__config__.blas_opt_info['include_dirs']])
['-L%s'%l for l in numpy.distutils.__config__.blas_opt_info['library_dirs']] +
['-l%s'%l for l in numpy.distutils.__config__.blas_opt_info['libraries']])
# ['-I%s'%l for l in numpy.distutils.__config__.blas_opt_info['include_dirs']])
except KeyError:
return "-lblas"
......@@ -124,7 +124,7 @@ AddConfigVar('blas.ldflags',
def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
"""Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation.
Default: ['blas'], but configuration variable config.blas.ldflags overrides this.
"""
rval = []
......@@ -139,7 +139,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
found_dyn=True
if not found_dyn and dirs:
warning("We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.")
for t in config.blas.ldflags.split():
try:
t0, t1, t2 = t[0:3]
......@@ -162,7 +162,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
class GemmRelated(Op):
"""Base class for Gemm and Dot22
This class provides a kind of templated gemm Op.
"""
def __eq__(self, other):
......@@ -186,14 +186,14 @@ class GemmRelated(Op):
"""
return blas_header_text() + mod_str
def c_headers(self):
# std.cout doesn't require the '%' symbol to print stuff...
# std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff.
return ['<iostream>', '<time.h>', '<sys/time.h>']
return ['<iostream>', '<time.h>', '<sys/time.h>']
def c_libraries(self):
return ldflags()
# code_cache_version is built by subclasses from
# code_cache_version is built by subclasses from
# build_gemm_version
def c_compile_args(self):
......@@ -201,10 +201,10 @@ class GemmRelated(Op):
def c_lib_dirs(self):
return ldflags(libs=False, libs_dir=True)
def c_header_dirs(self):
return ldflags(libs=False, include_dir=True)
declare_NS = """
int unit = 0;
......@@ -231,15 +231,15 @@ class GemmRelated(Op):
if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
"""
check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_zout)s->descr->type_num != PyArray_DOUBLE)
if ((%(_zout)s->descr->type_num != PyArray_DOUBLE)
&& (%(_zout)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
......@@ -262,21 +262,21 @@ class GemmRelated(Op):
check_dims_strides = """
if (Nx[0] != Nz[0])
{
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld rows but z has %%ld rows",
(long int)Nx[0], (long int)Nz[0]);
%(fail)s;
}
if (Nx[1] != Ny[0])
{
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld cols but y has %%ld rows",
(long int)Nx[1], (long int)Ny[0]);
%(fail)s;
}
if (Ny[1] != Nz[1])
{
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"Shape mismatch: y has %%ld cols but z has %%ld cols",
(long int)Ny[1], (long int)Nz[1]);
%(fail)s;
......@@ -413,11 +413,11 @@ class Gemm(GemmRelated):
When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b)
gemm(z,a,x,y,b)
is similar to
is similar to
b*z + a*dot(x,y)
b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z,
whereas the bottom form is not. Gemm works in-place on the storage
......@@ -450,7 +450,7 @@ class Gemm(GemmRelated):
def __setstate__(self, dct):
inplace = dct.get('inplace', True)
if inplace:
self.destroy_map = {0: [0]}
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
......@@ -577,7 +577,7 @@ class Gemm(GemmRelated):
case_float_ab_constants = """
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
......@@ -587,7 +587,7 @@ class Gemm(GemmRelated):
"""
case_double_ab_constants = """
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
......@@ -618,14 +618,14 @@ pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace'))
def res_is_a(node, op, maxclients=None):
if maxclients is not None:
retval = (len(node.clients) <= maxclients)
else:
retval = True
if maxclients is not None:
retval = (len(node.clients) <= maxclients)
else:
retval = True
return node.owner \
and node.owner.op == op \
and retval
return node.owner \
and node.owner.op == op \
and retval
def _as_scalar(res):
......@@ -654,7 +654,7 @@ def _is_real_matrix(res):
def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \
and res.type.broadcastable[0] == False
and res.type.broadcastable[0] == False
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
......@@ -680,7 +680,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
pass
if Mr.ndim == 2:
#print "RETURNING GEMV (case 2)"
if Mr.dtype == Ml.dtype:
if Mr.dtype == Ml.dtype:
rval = [gemv_no_inplace(L, alpha, Mr.T, Ml, beta)]
assert L.type == rval[0].type, (L.type, rval[0].type)
else:
......@@ -700,7 +700,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
pass
return rval
# this is False'd out because of inadequate testing.
# this is False'd out because of inadequate testing.
# TODO see ticket #237
if False and res_is_a(M, gemm_no_inplace, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(G, a, u, v, b)))
......@@ -860,7 +860,7 @@ def _gemm_from_factored_list(lst):
s_j, M_j = lst[j]
except:
continue
#print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j)
......@@ -874,7 +874,7 @@ def _gemm_from_factored_list(lst):
return s*M
assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input)
add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i,j)]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
......@@ -1050,7 +1050,7 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# run before specialize (2.0) because specialize is basically a free-for-all that makes the
# graph crazy.
blas_optdb.register('local_dot_to_dot22',
blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run')
blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
......@@ -1058,9 +1058,9 @@ blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
# After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
optdb.register('InplaceBlasOpt',
optdb.register('InplaceBlasOpt',
EquilibriumOptimizer([local_inplace_gemm, local_inplace_gemv], failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5),
max_use_ratio=5),
70.0, 'fast_run', 'inplace')
class Dot22Scalar(GemmRelated):
......@@ -1103,7 +1103,7 @@ class Dot22Scalar(GemmRelated):
"""
case_float_ab_constants = """
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
#undef REAL
......@@ -1111,7 +1111,7 @@ class Dot22Scalar(GemmRelated):
"""
case_double_ab_constants = """
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
#undef REAL
......@@ -1138,7 +1138,7 @@ def local_dot22_to_dot22scalar(node):
.. note:
We execute this optimizer after the gemm optimizer. This allow to give more priority to gemm that give more speed up then this optimizer, but allow the gemm optimizer to ignore this op.
TODO: support when we can reorder the mul to generate a dot22scalar or fix the canonizer to merge them(1 mul with multiple inputs)
"""
if node.op != T.mul:
......@@ -1154,7 +1154,7 @@ def local_dot22_to_dot22scalar(node):
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
return False
if not any(i_scalar):
#maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together.
......@@ -1207,4 +1207,3 @@ from opt import register_specialize, register_canonicalize
def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add):
debugprint(node)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论