提交 ce291230 authored 作者: Frederic Bastien's avatar Frederic Bastien

white space fix.

上级 3e933636
...@@ -4,8 +4,8 @@ import sys, traceback, logging, copy, os ...@@ -4,8 +4,8 @@ import sys, traceback, logging, copy, os
import numpy import numpy
import numpy.distutils import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, PatternSub, DestroyHandler, from theano.gof import (utils, Op, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer, SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer) InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer)
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
...@@ -17,7 +17,7 @@ import basic as T ...@@ -17,7 +17,7 @@ import basic as T
from theano.tensor.tsor_apply import Apply from theano.tensor.tsor_apply import Apply
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
from theano import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from theano.tensor.blas_headers import cblas_header_text, blas_header_text from theano.tensor.blas_headers import cblas_header_text, blas_header_text
...@@ -108,11 +108,11 @@ def default_blas_ldflags(): ...@@ -108,11 +108,11 @@ def default_blas_ldflags():
if all(not os.path.exists(dir) for dir in numpy.distutils.__config__.blas_opt_info['library_dirs']): if all(not os.path.exists(dir) for dir in numpy.distutils.__config__.blas_opt_info['library_dirs']):
return "-lblas" return "-lblas"
return ' '.join( return ' '.join(
#TODO: the Gemm op below should separate the -L and -l arguments into the two callbacks that CLinker uses for that stuff. #TODO: the Gemm op below should separate the -L and -l arguments into the two callbacks that CLinker uses for that stuff.
# for now, we just pass the whole ldflags as the -l options part. # for now, we just pass the whole ldflags as the -l options part.
['-L%s'%l for l in numpy.distutils.__config__.blas_opt_info['library_dirs']] + ['-L%s'%l for l in numpy.distutils.__config__.blas_opt_info['library_dirs']] +
['-l%s'%l for l in numpy.distutils.__config__.blas_opt_info['libraries']]) ['-l%s'%l for l in numpy.distutils.__config__.blas_opt_info['libraries']])
# ['-I%s'%l for l in numpy.distutils.__config__.blas_opt_info['include_dirs']]) # ['-I%s'%l for l in numpy.distutils.__config__.blas_opt_info['include_dirs']])
except KeyError: except KeyError:
return "-lblas" return "-lblas"
...@@ -124,7 +124,7 @@ AddConfigVar('blas.ldflags', ...@@ -124,7 +124,7 @@ AddConfigVar('blas.ldflags',
def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
"""Return a list of libraries against which an Op's object file should be """Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation. linked to benefit from a BLAS implementation.
Default: ['blas'], but configuration variable config.blas.ldflags overrides this. Default: ['blas'], but configuration variable config.blas.ldflags overrides this.
""" """
rval = [] rval = []
...@@ -139,7 +139,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): ...@@ -139,7 +139,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
found_dyn=True found_dyn=True
if not found_dyn and dirs: if not found_dyn and dirs:
warning("We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.") warning("We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.")
for t in config.blas.ldflags.split(): for t in config.blas.ldflags.split():
try: try:
t0, t1, t2 = t[0:3] t0, t1, t2 = t[0:3]
...@@ -162,7 +162,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): ...@@ -162,7 +162,7 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
class GemmRelated(Op): class GemmRelated(Op):
"""Base class for Gemm and Dot22 """Base class for Gemm and Dot22
This class provides a kind of templated gemm Op. This class provides a kind of templated gemm Op.
""" """
def __eq__(self, other): def __eq__(self, other):
...@@ -186,14 +186,14 @@ class GemmRelated(Op): ...@@ -186,14 +186,14 @@ class GemmRelated(Op):
""" """
return blas_header_text() + mod_str return blas_header_text() + mod_str
def c_headers(self): def c_headers(self):
# std.cout doesn't require the '%' symbol to print stuff... # std.cout doesn't require the '%' symbol to print stuff...
# so it works much better with python's string-substitution stuff. # so it works much better with python's string-substitution stuff.
return ['<iostream>', '<time.h>', '<sys/time.h>'] return ['<iostream>', '<time.h>', '<sys/time.h>']
def c_libraries(self): def c_libraries(self):
return ldflags() return ldflags()
# code_cache_version is built by subclasses from # code_cache_version is built by subclasses from
# build_gemm_version # build_gemm_version
def c_compile_args(self): def c_compile_args(self):
...@@ -201,10 +201,10 @@ class GemmRelated(Op): ...@@ -201,10 +201,10 @@ class GemmRelated(Op):
def c_lib_dirs(self): def c_lib_dirs(self):
return ldflags(libs=False, libs_dir=True) return ldflags(libs=False, libs_dir=True)
def c_header_dirs(self): def c_header_dirs(self):
return ldflags(libs=False, include_dir=True) return ldflags(libs=False, include_dir=True)
declare_NS = """ declare_NS = """
int unit = 0; int unit = 0;
...@@ -231,15 +231,15 @@ class GemmRelated(Op): ...@@ -231,15 +231,15 @@ class GemmRelated(Op):
if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;} if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
""" """
check_xyz_double_or_float = """ check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE) if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT)) && (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE) if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT)) && (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_zout)s->descr->type_num != PyArray_DOUBLE) if ((%(_zout)s->descr->type_num != PyArray_DOUBLE)
&& (%(_zout)s->descr->type_num != PyArray_FLOAT)) && (%(_zout)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
...@@ -262,21 +262,21 @@ class GemmRelated(Op): ...@@ -262,21 +262,21 @@ class GemmRelated(Op):
check_dims_strides = """ check_dims_strides = """
if (Nx[0] != Nz[0]) if (Nx[0] != Nz[0])
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld rows but z has %%ld rows", "Shape mismatch: x has %%ld rows but z has %%ld rows",
(long int)Nx[0], (long int)Nz[0]); (long int)Nx[0], (long int)Nz[0]);
%(fail)s; %(fail)s;
} }
if (Nx[1] != Ny[0]) if (Nx[1] != Ny[0])
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld cols but y has %%ld rows", "Shape mismatch: x has %%ld cols but y has %%ld rows",
(long int)Nx[1], (long int)Ny[0]); (long int)Nx[1], (long int)Ny[0]);
%(fail)s; %(fail)s;
} }
if (Ny[1] != Nz[1]) if (Ny[1] != Nz[1])
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: y has %%ld cols but z has %%ld cols", "Shape mismatch: y has %%ld cols but z has %%ld cols",
(long int)Ny[1], (long int)Nz[1]); (long int)Ny[1], (long int)Nz[1]);
%(fail)s; %(fail)s;
...@@ -413,11 +413,11 @@ class Gemm(GemmRelated): ...@@ -413,11 +413,11 @@ class Gemm(GemmRelated):
When a and b are scalars and x, y, and z are matrices, then When a and b are scalars and x, y, and z are matrices, then
gemm(z,a,x,y,b) gemm(z,a,x,y,b)
is similar to is similar to
b*z + a*dot(x,y) b*z + a*dot(x,y)
The difference between the two is that the top form is destructive on z, The difference between the two is that the top form is destructive on z,
whereas the bottom form is not. Gemm works in-place on the storage whereas the bottom form is not. Gemm works in-place on the storage
...@@ -450,7 +450,7 @@ class Gemm(GemmRelated): ...@@ -450,7 +450,7 @@ class Gemm(GemmRelated):
def __setstate__(self, dct): def __setstate__(self, dct):
inplace = dct.get('inplace', True) inplace = dct.get('inplace', True)
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else: else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
...@@ -577,7 +577,7 @@ class Gemm(GemmRelated): ...@@ -577,7 +577,7 @@ class Gemm(GemmRelated):
case_float_ab_constants = """ case_float_ab_constants = """
#define REAL float #define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT) float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0]) ? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]); : (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ? float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
...@@ -587,7 +587,7 @@ class Gemm(GemmRelated): ...@@ -587,7 +587,7 @@ class Gemm(GemmRelated):
""" """
case_double_ab_constants = """ case_double_ab_constants = """
#define REAL double #define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT) double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0]) ? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]); : (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ? double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
...@@ -618,14 +618,14 @@ pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace')) ...@@ -618,14 +618,14 @@ pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace')) pprint.assign(gemm_no_inplace, FunctionPrinter('gemm_no_inplace'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
if maxclients is not None: if maxclients is not None:
retval = (len(node.clients) <= maxclients) retval = (len(node.clients) <= maxclients)
else: else:
retval = True retval = True
return node.owner \ return node.owner \
and node.owner.op == op \ and node.owner.op == op \
and retval and retval
def _as_scalar(res): def _as_scalar(res):
...@@ -654,7 +654,7 @@ def _is_real_matrix(res): ...@@ -654,7 +654,7 @@ def _is_real_matrix(res):
def _is_real_vector(res): def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \ return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \ and res.type.ndim == 1 \
and res.type.broadcastable[0] == False and res.type.broadcastable[0] == False
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
...@@ -680,7 +680,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -680,7 +680,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
pass pass
if Mr.ndim == 2: if Mr.ndim == 2:
#print "RETURNING GEMV (case 2)" #print "RETURNING GEMV (case 2)"
if Mr.dtype == Ml.dtype: if Mr.dtype == Ml.dtype:
rval = [gemv_no_inplace(L, alpha, Mr.T, Ml, beta)] rval = [gemv_no_inplace(L, alpha, Mr.T, Ml, beta)]
assert L.type == rval[0].type, (L.type, rval[0].type) assert L.type == rval[0].type, (L.type, rval[0].type)
else: else:
...@@ -700,7 +700,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -700,7 +700,7 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
pass pass
return rval return rval
# this is False'd out because of inadequate testing. # this is False'd out because of inadequate testing.
# TODO see ticket #237 # TODO see ticket #237
if False and res_is_a(M, gemm_no_inplace, 1): if False and res_is_a(M, gemm_no_inplace, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(G, a, u, v, b))) #EXPRESSION: (beta * L) + (alpha * (gemm_no_inplace(G, a, u, v, b)))
...@@ -860,7 +860,7 @@ def _gemm_from_factored_list(lst): ...@@ -860,7 +860,7 @@ def _gemm_from_factored_list(lst):
s_j, M_j = lst[j] s_j, M_j = lst[j]
except: except:
continue continue
#print 'TRYING', (s_i, M_i, s_j, M_j) #print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j) gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j)
...@@ -874,7 +874,7 @@ def _gemm_from_factored_list(lst): ...@@ -874,7 +874,7 @@ def _gemm_from_factored_list(lst):
return s*M return s*M
assert len(gemm_of_sM_list) == 1 assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input) add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i,j)] for k, input in enumerate(lst) if k not in (i,j)]
add_inputs.extend(gemm_of_sM_list) add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1: if len(add_inputs) > 1:
...@@ -1050,7 +1050,7 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run') ...@@ -1050,7 +1050,7 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# run before specialize (2.0) because specialize is basically a free-for-all that makes the # run before specialize (2.0) because specialize is basically a free-for-all that makes the
# graph crazy. # graph crazy.
blas_optdb.register('local_dot_to_dot22', blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5), EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run') 0, 'fast_run')
blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run') blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
...@@ -1058,9 +1058,9 @@ blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run') ...@@ -1058,9 +1058,9 @@ blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
# After destroyhandler is in but before we try to make elemwise things inplace # After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace # Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71) # Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
optdb.register('InplaceBlasOpt', optdb.register('InplaceBlasOpt',
EquilibriumOptimizer([local_inplace_gemm, local_inplace_gemv], failure_callback=EquilibriumOptimizer.warn_inplace, EquilibriumOptimizer([local_inplace_gemm, local_inplace_gemv], failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5), max_use_ratio=5),
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
class Dot22Scalar(GemmRelated): class Dot22Scalar(GemmRelated):
...@@ -1103,7 +1103,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1103,7 +1103,7 @@ class Dot22Scalar(GemmRelated):
""" """
case_float_ab_constants = """ case_float_ab_constants = """
#define REAL float #define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT) float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0]) ? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]); : (REAL)(((double*)%(_a)s->data)[0]);
#undef REAL #undef REAL
...@@ -1111,7 +1111,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1111,7 +1111,7 @@ class Dot22Scalar(GemmRelated):
""" """
case_double_ab_constants = """ case_double_ab_constants = """
#define REAL double #define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT) double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0]) ? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]); : (REAL)(((double*)%(_a)s->data)[0]);
#undef REAL #undef REAL
...@@ -1138,7 +1138,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1138,7 +1138,7 @@ def local_dot22_to_dot22scalar(node):
.. note: .. note:
We execute this optimizer after the gemm optimizer. This allow to give more priority to gemm that give more speed up then this optimizer, but allow the gemm optimizer to ignore this op. We execute this optimizer after the gemm optimizer. This allow to give more priority to gemm that give more speed up then this optimizer, but allow the gemm optimizer to ignore this op.
TODO: support when we can reorder the mul to generate a dot22scalar or fix the canonizer to merge them(1 mul with multiple inputs) TODO: support when we can reorder the mul to generate a dot22scalar or fix the canonizer to merge them(1 mul with multiple inputs)
""" """
if node.op != T.mul: if node.op != T.mul:
...@@ -1154,7 +1154,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1154,7 +1154,7 @@ def local_dot22_to_dot22scalar(node):
#no scalar in input and no multiplication #no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph. #if their was a multiplication we couls reorder the graph by the associativity of the graph.
return False return False
if not any(i_scalar): if not any(i_scalar):
#maybe we can reorder the graph as this mul have a mul in input. #maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together. #The canonizer should have merged those mul together.
...@@ -1207,4 +1207,3 @@ from opt import register_specialize, register_canonicalize ...@@ -1207,4 +1207,3 @@ from opt import register_specialize, register_canonicalize
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (T.sub, T.add):
debugprint(node) debugprint(node)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论